mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 05:31:40 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""RDBMS Connector Module."""
|
||||
|
@@ -1,86 +0,0 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.storage.schema import DBType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class BaseDao:
|
||||
def __init__(
|
||||
self, orm_base=None, database: str = None, create_not_exist_table: bool = False
|
||||
) -> None:
|
||||
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
|
||||
self._orm_base = orm_base
|
||||
self._database = database
|
||||
self._create_not_exist_table = create_not_exist_table
|
||||
|
||||
self._db_engine = None
|
||||
self._session = None
|
||||
self._connection = None
|
||||
|
||||
@property
|
||||
def db_engine(self):
|
||||
if not self._db_engine:
|
||||
# lazy loading
|
||||
db_engine, connection = _get_db_engine(
|
||||
self._orm_base, self._database, self._create_not_exist_table
|
||||
)
|
||||
self._db_engine = db_engine
|
||||
self._connection = connection
|
||||
return self._db_engine
|
||||
|
||||
@property
|
||||
def Session(self):
|
||||
if not self._session:
|
||||
self._session = sessionmaker(bind=self.db_engine)
|
||||
return self._session
|
||||
|
||||
|
||||
def _get_db_engine(
|
||||
orm_base=None, database: str = None, create_not_exist_table: bool = False
|
||||
):
|
||||
db_engine = None
|
||||
connection: RDBMSDatabase = None
|
||||
|
||||
db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
|
||||
if db_type is None or db_type == DBType.Mysql:
|
||||
# default database
|
||||
db_engine = create_engine(
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
echo=True,
|
||||
)
|
||||
else:
|
||||
db_namager = CFG.LOCAL_DB_MANAGE
|
||||
if not db_namager:
|
||||
raise Exception(
|
||||
"LOCAL_DB_MANAGE is not initialized, please check the system configuration"
|
||||
)
|
||||
if db_type.is_file_db():
|
||||
db_path = CFG.LOCAL_DB_PATH
|
||||
if db_path is None or db_path == "":
|
||||
raise ValueError(
|
||||
"You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file"
|
||||
)
|
||||
_, database = db_namager._parse_file_db_info(db_type.value(), db_path)
|
||||
logger.info(
|
||||
f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}"
|
||||
)
|
||||
logger.info(f"Get DAO database connection with database name {database}")
|
||||
connection: RDBMSDatabase = db_namager.get_connect(database)
|
||||
if not isinstance(connection, RDBMSDatabase):
|
||||
raise ValueError(
|
||||
"Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration"
|
||||
)
|
||||
db_engine = connection._engine
|
||||
|
||||
if db_type.is_file_db() and orm_base is not None and create_not_exist_table:
|
||||
logger.info("Current database is file database, create not exist table")
|
||||
orm_base.metadata.create_all(db_engine)
|
||||
|
||||
return db_engine, connection
|
@@ -1,6 +1,9 @@
|
||||
"""Base class for RDBMS connectors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
@@ -13,11 +16,10 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker
|
||||
from sqlalchemy.schema import CreateTable
|
||||
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
from dbgpt.datasource.base import BaseConnector
|
||||
from dbgpt.storage.schema import DBType
|
||||
|
||||
CFG = Config()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
@@ -27,11 +29,9 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
)
|
||||
|
||||
|
||||
class RDBMSDatabase(BaseConnect):
|
||||
class RDBMSConnector(BaseConnector):
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
db_type: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
@@ -41,10 +41,11 @@ class RDBMSDatabase(BaseConnect):
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
custom_table_info: Optional[Dict[str, str]] = None,
|
||||
view_support: bool = False,
|
||||
):
|
||||
"""Create engine from database URI.
|
||||
|
||||
Args:
|
||||
- engine: Engine sqlalchemy.engine
|
||||
- schema: Optional[str].
|
||||
@@ -61,28 +62,27 @@ class RDBMSDatabase(BaseConnect):
|
||||
if include_tables and ignore_tables:
|
||||
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||
|
||||
if not custom_table_info:
|
||||
custom_table_info = {}
|
||||
|
||||
self._inspector = inspect(engine)
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
Session_Manages = scoped_session(session_factory)
|
||||
self._db_sessions = Session_Manages
|
||||
self.session = self.get_session()
|
||||
|
||||
self._all_tables = set()
|
||||
self.view_support = False
|
||||
self._usable_tables = set()
|
||||
self._include_tables = set()
|
||||
self._ignore_tables = set()
|
||||
self._custom_table_info = set()
|
||||
self._indexes_in_table_info = set()
|
||||
self._usable_tables = set()
|
||||
self._usable_tables = set()
|
||||
self._sample_rows_in_table_info = set()
|
||||
self.view_support = view_support
|
||||
self._usable_tables: Set[str] = set()
|
||||
self._include_tables: Set[str] = set()
|
||||
self._ignore_tables: Set[str] = set()
|
||||
self._custom_table_info = custom_table_info
|
||||
self._sample_rows_in_table_info = sample_rows_in_table_info
|
||||
self._indexes_in_table_info = indexes_in_table_info
|
||||
|
||||
self._metadata = MetaData()
|
||||
self._metadata = metadata or MetaData()
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
|
||||
self._all_tables = self._sync_tables_from_db()
|
||||
self._all_tables: Set[str] = cast(Set[str], self._sync_tables_from_db())
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
@@ -94,8 +94,9 @@ class RDBMSDatabase(BaseConnect):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> RDBMSConnector:
|
||||
"""Construct a SQLAlchemy engine from uri database.
|
||||
|
||||
Args:
|
||||
host (str): database host.
|
||||
port (int): database port.
|
||||
@@ -112,7 +113,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
) -> RDBMSConnector:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
return cls(create_engine(database_uri, **_engine_args), **kwargs)
|
||||
@@ -123,7 +124,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
return self._engine.dialect.name
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
"""Read table information from database"""
|
||||
"""Read table information from database."""
|
||||
# TODO Use a background thread to refresh periodically
|
||||
|
||||
# SQL will raise error with schema
|
||||
@@ -153,16 +154,25 @@ class RDBMSDatabase(BaseConnect):
|
||||
return self.get_usable_table_names()
|
||||
|
||||
def get_session(self):
|
||||
"""Get session."""
|
||||
session = self._db_sessions()
|
||||
|
||||
return session
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
"""Get current database name.
|
||||
|
||||
Returns:
|
||||
str: database name
|
||||
"""
|
||||
return self.session.execute(text("SELECT DATABASE()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
"""Return table simple info."""
|
||||
_sql = f"""
|
||||
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
|
||||
select concat(table_name, "(" , group_concat(column_name), ")")
|
||||
as schema_info from information_schema.COLUMNS where
|
||||
table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
|
||||
"""
|
||||
cursor = self.session.execute(text(_sql))
|
||||
results = cursor.fetchall()
|
||||
@@ -222,12 +232,16 @@ class RDBMSDatabase(BaseConnect):
|
||||
return final_str
|
||||
|
||||
def get_columns(self, table_name: str) -> List[Dict]:
|
||||
"""Get columns.
|
||||
"""Get columns about specified table.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
|
||||
Returns:
|
||||
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
columns: List[Dict], which contains name: str, type: str,
|
||||
default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
|
||||
'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
"""
|
||||
return self._inspector.get_columns(table_name)
|
||||
|
||||
@@ -280,13 +294,14 @@ class RDBMSDatabase(BaseConnect):
|
||||
Args:
|
||||
write_sql (str): SQL write command to run
|
||||
"""
|
||||
print(f"Write[{write_sql}]")
|
||||
logger.info(f"Write[{write_sql}]")
|
||||
db_cache = self._engine.url.database
|
||||
result = self.session.execute(text(write_sql))
|
||||
self.session.commit()
|
||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
# TODO Subsequent optimization of dynamically specified database submission
|
||||
# loss target problem
|
||||
self.session.execute(text(f"use `{db_cache}`"))
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
logger.info(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def _query(self, query: str, fetch: str = "all"):
|
||||
@@ -296,9 +311,9 @@ class RDBMSDatabase(BaseConnect):
|
||||
query (str): SQL query to run
|
||||
fetch (str): fetch type
|
||||
"""
|
||||
result = []
|
||||
result: List[Any] = []
|
||||
|
||||
print(f"Query[{query}]")
|
||||
logger.info(f"Query[{query}]")
|
||||
if not query:
|
||||
return result
|
||||
cursor = self.session.execute(text(query))
|
||||
@@ -314,20 +329,28 @@ class RDBMSDatabase(BaseConnect):
|
||||
result.insert(0, field_names)
|
||||
return result
|
||||
|
||||
def query_table_schema(self, table_name):
|
||||
def query_table_schema(self, table_name: str):
|
||||
"""Query table schema.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
"""
|
||||
sql = f"select * from {table_name} limit 1"
|
||||
return self._query(sql)
|
||||
|
||||
def query_ex(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
def query_ex(self, query: str, fetch: str = "all"):
|
||||
"""Execute a SQL command and return the results.
|
||||
|
||||
Only for query command.
|
||||
|
||||
Args:
|
||||
session:
|
||||
query:
|
||||
fetch:
|
||||
query (str): SQL query to run
|
||||
fetch (str): fetch type
|
||||
|
||||
Returns:
|
||||
List: result list
|
||||
"""
|
||||
print(f"Query[{query}]")
|
||||
logger.info(f"Query[{query}]")
|
||||
if not query:
|
||||
return [], None
|
||||
cursor = self.session.execute(text(query))
|
||||
@@ -338,7 +361,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
result = cursor.fetchone() # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = list(i[0:] for i in cursor.keys())
|
||||
field_names = list(cursor.keys())
|
||||
|
||||
result = list(result)
|
||||
return field_names, result
|
||||
@@ -346,7 +369,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
logger.info("SQL:" + command)
|
||||
if not command or len(command) < 0:
|
||||
return []
|
||||
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
@@ -356,11 +379,13 @@ class RDBMSDatabase(BaseConnect):
|
||||
else:
|
||||
self._write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
logger.info(f"write result query:{select_sql}")
|
||||
return self._query(select_sql)
|
||||
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
logger.info(
|
||||
"DDL execution determines whether to enable through configuration "
|
||||
)
|
||||
cursor = self.session.execute(text(command))
|
||||
self.session.commit()
|
||||
if cursor.returns_rows:
|
||||
@@ -368,7 +393,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
logger.info("DDL Result:" + str(result))
|
||||
if not result:
|
||||
# return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.get_simple_fields(table_name)
|
||||
@@ -377,6 +402,7 @@ class RDBMSDatabase(BaseConnect):
|
||||
return self.get_simple_fields(table_name)
|
||||
|
||||
def run_to_df(self, command: str, fetch: str = "all"):
|
||||
"""Execute sql command and return result as dataframe."""
|
||||
import pandas as pd
|
||||
|
||||
# Pandas has too much dependence and the import time is too long
|
||||
@@ -398,44 +424,45 @@ class RDBMSDatabase(BaseConnect):
|
||||
return self.run(command, fetch)
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
logger.warning(f"Run SQL command failed: {e}")
|
||||
return []
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
def convert_sql_write_to_select(self, write_sql: str) -> str:
|
||||
"""Convert SQL write command to a SELECT command.
|
||||
|
||||
def convert_sql_write_to_select(self, write_sql):
|
||||
"""
|
||||
SQL classification processing
|
||||
author:xiangh8
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
write_sql = "insert into test(id) values (1)"
|
||||
select_sql = convert_sql_write_to_select(write_sql)
|
||||
print(select_sql)
|
||||
# SELECT * FROM test WHERE id=1
|
||||
Args:
|
||||
sql:
|
||||
write_sql (str): SQL write command
|
||||
|
||||
Returns:
|
||||
|
||||
str: SELECT command corresponding to the write command
|
||||
"""
|
||||
# 将SQL命令转换为小写,并按空格拆分
|
||||
# Convert the SQL command to lowercase and split by space
|
||||
parts = write_sql.lower().split()
|
||||
# 获取命令类型(insert, delete, update)
|
||||
# Get the command type (insert, delete, update)
|
||||
cmd_type = parts[0]
|
||||
|
||||
# 根据命令类型进行处理
|
||||
# Handle according to command type
|
||||
if cmd_type == "insert":
|
||||
match = re.match(
|
||||
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
|
||||
r"insert\s+into\s+(\w+)\s*\(([^)]+)\)\s*values\s*\(([^)]+)\)",
|
||||
write_sql.lower(),
|
||||
)
|
||||
if match:
|
||||
# Get the table name, columns, and values
|
||||
table_name, columns, values = match.groups()
|
||||
# 将字段列表和值列表分割为单独的字段和值
|
||||
columns = columns.split(",")
|
||||
values = values.split(",")
|
||||
# 构造 WHERE 子句
|
||||
# Build the WHERE clause
|
||||
where_clause = " AND ".join(
|
||||
[
|
||||
f"{col.strip()}={val.strip()}"
|
||||
@@ -443,21 +470,23 @@ class RDBMSDatabase(BaseConnect):
|
||||
]
|
||||
)
|
||||
return f"SELECT * FROM {table_name} WHERE {where_clause}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported SQL command: {write_sql}")
|
||||
|
||||
elif cmd_type == "delete":
|
||||
table_name = parts[2] # delete from <table_name> ...
|
||||
# 返回一个select语句,它选择该表的所有数据
|
||||
# Return a SELECT statement that selects all data from the table
|
||||
return f"SELECT * FROM {table_name} "
|
||||
|
||||
elif cmd_type == "update":
|
||||
table_name = parts[1]
|
||||
set_idx = parts.index("set")
|
||||
where_idx = parts.index("where")
|
||||
# 截取 `set` 子句中的字段名
|
||||
# Get the field name in the `set` clause
|
||||
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
|
||||
# 截取 `where` 之后的条件语句
|
||||
# Get the condition statement after the `where`
|
||||
where_clause = " ".join(parts[where_idx + 1 :])
|
||||
# 返回一个select语句,它选择更新的数据
|
||||
# Return a SELECT statement that selects the updated data
|
||||
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
|
||||
@@ -473,7 +502,9 @@ class RDBMSDatabase(BaseConnect):
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
|
||||
logger.info(
|
||||
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
|
||||
)
|
||||
return parsed, ttype, sql_type, table_name
|
||||
|
||||
def _extract_table_name_from_ddl(self, parsed):
|
||||
@@ -485,8 +516,10 @@ class RDBMSDatabase(BaseConnect):
|
||||
|
||||
def get_indexes(self, table_name: str) -> List[Dict]:
|
||||
"""Get table indexes about specified table.
|
||||
|
||||
Args:
|
||||
table_name:(str) table name
|
||||
|
||||
Returns:
|
||||
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
|
||||
"""
|
||||
@@ -504,9 +537,9 @@ class RDBMSDatabase(BaseConnect):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
|
||||
table_name
|
||||
)
|
||||
"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
|
||||
"COLUMN_COMMENT from information_schema.COLUMNS where "
|
||||
f"table_name='{table_name}'".format(table_name)
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
@@ -516,40 +549,41 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""Get column fields about specified table."""
|
||||
return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def get_charset(self):
|
||||
def get_charset(self) -> str:
|
||||
"""Get character_set."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SELECT @@character_set_database"))
|
||||
character_set = cursor.fetchone()[0]
|
||||
cursor = session.execute(text("SELECT @@character_set_database"))
|
||||
character_set = cursor.fetchone()[0] # type: ignore
|
||||
return character_set
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SELECT @@collation_database"))
|
||||
cursor = session.execute(text("SELECT @@collation_database"))
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grant info."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW GRANTS"))
|
||||
cursor = session.execute(text("SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
return grants
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
try:
|
||||
cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user"))
|
||||
cursor = self.session.execute(text("SELECT user, host FROM mysql.user"))
|
||||
users = cursor.fetchall()
|
||||
return [(user[0], user[1]) for user in users]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def get_table_comments(self, db_name: str):
|
||||
"""Return table comments."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""SELECT table_name, table_comment FROM information_schema.tables
|
||||
f"""SELECT table_name, table_comment FROM information_schema.tables
|
||||
WHERE table_schema = '{db_name}'""".format(
|
||||
db_name
|
||||
)
|
||||
@@ -570,10 +604,11 @@ class RDBMSDatabase(BaseConnect):
|
||||
"""
|
||||
return self._inspector.get_table_comment(table_name)
|
||||
|
||||
def get_column_comments(self, db_name, table_name):
|
||||
def get_column_comments(self, db_name: str, table_name: str):
|
||||
"""Return column comments."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""SELECT column_name, column_comment FROM information_schema.columns
|
||||
f"""SELECT column_name, column_comment FROM information_schema.columns
|
||||
WHERE table_schema = '{db_name}' and table_name = '{table_name}'
|
||||
""".format(
|
||||
db_name, table_name
|
||||
@@ -585,17 +620,12 @@ class RDBMSDatabase(BaseConnect):
|
||||
(column_comment[0], column_comment[1]) for column_comment in column_comments
|
||||
]
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
def get_database_names(self) -> List[str]:
|
||||
"""Return a list of database names available in the database.
|
||||
|
||||
def get_database_names(self):
|
||||
Returns:
|
||||
List[str]: database list
|
||||
"""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
|
@@ -1,18 +1,20 @@
|
||||
"""Clickhouse connector."""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import MetaData, text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.storage.schema import DBType
|
||||
|
||||
from .base import RDBMSConnector
|
||||
|
||||
class ClickhouseConnect(RDBMSDatabase):
|
||||
"""Connect Clickhouse Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickhouseConnector(RDBMSConnector):
|
||||
"""Clickhouse connector."""
|
||||
|
||||
"""db type"""
|
||||
db_type: str = "clickhouse"
|
||||
@@ -24,6 +26,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
client: Any = None
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
"""Create a new ClickhouseConnector from client."""
|
||||
self.client = client
|
||||
|
||||
self._all_tables = set()
|
||||
@@ -49,7 +52,8 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> "ClickhouseConnector":
|
||||
"""Create a new ClickhouseConnector from host, port, user, pwd, db_name."""
|
||||
import clickhouse_connect
|
||||
from clickhouse_connect.driver import httputil
|
||||
|
||||
@@ -70,11 +74,6 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
cls.client = client
|
||||
return cls(client, **kwargs)
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of dialect to use."""
|
||||
pass
|
||||
|
||||
def get_table_names(self):
|
||||
"""Get all table names."""
|
||||
session = self.client
|
||||
@@ -85,6 +84,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_indexes(self, table_name: str) -> List[Dict]:
|
||||
"""Get table indexes about specified table.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
Returns:
|
||||
@@ -93,10 +93,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
|
||||
SELECT name AS table, primary_key, from system.tables where
|
||||
database ='{self.client.database}' and table = '{table_name}'
|
||||
"""
|
||||
with session.query_row_block_stream(_query_sql) as stream:
|
||||
indexes = [block for block in stream]
|
||||
indexes = [block for block in stream] # noqa
|
||||
return [
|
||||
{"name": "primary_key", "column_names": column_names.split(",")}
|
||||
for table, column_names in indexes[0]
|
||||
@@ -104,6 +105,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Get table info."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
@@ -117,7 +119,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
demonstrated in the paper.
|
||||
"""
|
||||
# TODO:
|
||||
pass
|
||||
return ""
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
@@ -133,11 +135,14 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_columns(self, table_name: str) -> List[Dict]:
|
||||
"""Get columns.
|
||||
|
||||
Args:
|
||||
table_name (str): str
|
||||
Returns:
|
||||
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
List[Dict], which contains name: str, type: str,
|
||||
default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
|
||||
'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
"""
|
||||
fields = self.get_fields(table_name)
|
||||
return [
|
||||
@@ -150,18 +155,21 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
|
||||
SELECT name, type, default_expression, is_in_primary_key, comment
|
||||
from system.columns where table='{table_name}'
|
||||
""".format(
|
||||
table_name
|
||||
)
|
||||
with session.query_row_block_stream(_query_sql) as stream:
|
||||
fields = [block for block in stream]
|
||||
fields = [block for block in stream] # noqa
|
||||
return fields
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
@@ -169,9 +177,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
session = self.client
|
||||
|
||||
with session.command("SHOW DATABASES") as stream:
|
||||
@@ -184,12 +194,10 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
]
|
||||
return databases
|
||||
|
||||
def get_database_names(self):
|
||||
return self.get_database_list()
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute sql command."""
|
||||
# TODO need to be implemented
|
||||
print("SQL:" + command)
|
||||
logger.info("SQL:" + command)
|
||||
if not command or len(command) < 0:
|
||||
return []
|
||||
_, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
@@ -199,10 +207,12 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
else:
|
||||
self._write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
logger.info(f"write result query:{select_sql}")
|
||||
return self._query(select_sql)
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
logger.info(
|
||||
"DDL execution determines whether to enable through configuration "
|
||||
)
|
||||
|
||||
cursor = self.client.command(command)
|
||||
|
||||
@@ -212,7 +222,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
logger.info("DDL Result:" + str(result))
|
||||
if not result:
|
||||
# return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.get_simple_fields(table_name)
|
||||
@@ -225,13 +235,16 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def get_current_db_name(self):
|
||||
"""Get current database name."""
|
||||
return self.client.database
|
||||
|
||||
def get_table_comments(self, db_name: str):
|
||||
"""Get table comments."""
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
|
||||
SELECT table, comment FROM system.tables WHERE database = '{db_name}'
|
||||
""".format(
|
||||
db_name
|
||||
)
|
||||
|
||||
@@ -241,6 +254,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_table_comment(self, table_name: str) -> Dict:
|
||||
"""Get table comment.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
Returns:
|
||||
@@ -249,7 +263,9 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
|
||||
SELECT table, comment FROM system.tables WHERE
|
||||
database = '{self.client.database}'and table = '{table_name}'
|
||||
""".format(
|
||||
self.client.database
|
||||
)
|
||||
|
||||
@@ -258,9 +274,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return [{"text": comment} for table_name, comment in table_comments][0]
|
||||
|
||||
def get_column_comments(self, db_name, table_name):
|
||||
"""Get column comments."""
|
||||
session = self.client
|
||||
_query_sql = f"""
|
||||
select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
|
||||
select name column, comment from system.columns where database='{db_name}'
|
||||
and table='{table_name}'
|
||||
""".format(
|
||||
db_name, table_name
|
||||
)
|
||||
@@ -270,10 +288,13 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return column_comments
|
||||
|
||||
def table_simple_info(self):
|
||||
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
|
||||
"""Get table simple info."""
|
||||
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray
|
||||
# instead; and quotes need to be escaped
|
||||
|
||||
_sql = f"""
|
||||
SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
|
||||
SELECT concat(TABLE_NAME, '(', arrayStringConcat(
|
||||
groupArray(column_name), '-'), ')') AS schema_info
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE table_schema = '{self.get_current_db_name()}'
|
||||
GROUP BY TABLE_NAME
|
||||
@@ -282,18 +303,18 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return [row[0] for block in stream for row in block]
|
||||
|
||||
def _write(self, write_sql: str):
|
||||
"""write data
|
||||
"""Execute write sql.
|
||||
|
||||
Args:
|
||||
write_sql (str): sql string
|
||||
"""
|
||||
# TODO need to be implemented
|
||||
print(f"Write[{write_sql}]")
|
||||
logger.info(f"Write[{write_sql}]")
|
||||
result = self.client.command(write_sql)
|
||||
print(f"SQL[{write_sql}], result:{result.written_rows}")
|
||||
logger.info(f"SQL[{write_sql}], result:{result.written_rows}")
|
||||
|
||||
def _query(self, query: str, fetch: str = "all"):
|
||||
"""Query data from clickhouse
|
||||
"""Query data from clickhouse.
|
||||
|
||||
Args:
|
||||
query (str): sql string
|
||||
@@ -306,7 +327,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
_type_: List<Result>
|
||||
"""
|
||||
# TODO need to be implemented
|
||||
print(f"Query[{query}]")
|
||||
logger.info(f"Query[{query}]")
|
||||
|
||||
if not query:
|
||||
return []
|
||||
@@ -334,11 +355,13 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
|
||||
logger.info(
|
||||
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
|
||||
)
|
||||
return parsed, ttype, sql_type, table_name
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
"""Read table information from database"""
|
||||
"""Read table information from database."""
|
||||
# TODO Use a background thread to refresh periodically
|
||||
|
||||
# SQL will raise error with schema
|
||||
|
@@ -1,13 +1,16 @@
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
"""Doris connector."""
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
|
||||
class DorisConnect(RDBMSDatabase):
|
||||
class DorisConnector(RDBMSConnector):
|
||||
"""Doris connector."""
|
||||
|
||||
driver = "doris"
|
||||
db_type = "doris"
|
||||
db_dialect = "doris"
|
||||
@@ -22,24 +25,27 @@ class DorisConnect(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> "DorisConnector":
|
||||
"""Create a new DorisConnector from host, port, user, pwd, db_name."""
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
return cast(DorisConnector, cls.from_uri(db_url, engine_args, **kwargs))
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=database()"
|
||||
"SELECT TABLE_NAME FROM information_schema.tables where "
|
||||
"TABLE_SCHEMA=database()"
|
||||
)
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
table_results = set(row[0] for row in table_results) # noqa: C401
|
||||
self._all_tables = table_results
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
cursor = self.get_session().execute(text("SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
if len(grants) == 0:
|
||||
@@ -51,14 +57,17 @@ class DorisConnect(RDBMSDatabase):
|
||||
return grants_list
|
||||
|
||||
def _get_current_version(self):
|
||||
"""Get database current version"""
|
||||
"""Get database current version."""
|
||||
return int(
|
||||
self.get_session().execute(text("select current_version()")).scalar()
|
||||
)
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation.
|
||||
ref: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-reference/Show-Statements/SHOW-COLLATION/
|
||||
|
||||
ref `SHOW COLLATION <https://doris.apache.org/zh-CN/docs/dev/sql-manual/
|
||||
sql-reference/Show-Statements/SHOW-COLLATION/>`_
|
||||
|
||||
"""
|
||||
cursor = self.get_session().execute(text("SHOW COLLATION"))
|
||||
results = cursor.fetchall()
|
||||
@@ -70,11 +79,14 @@ class DorisConnect(RDBMSDatabase):
|
||||
|
||||
def get_columns(self, table_name: str) -> List[Dict]:
|
||||
"""Get columns.
|
||||
|
||||
Args:
|
||||
table_name (str): str
|
||||
Returns:
|
||||
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
columns: List[Dict], which contains name: str, type: str,
|
||||
default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
|
||||
'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
"""
|
||||
fields = self.get_fields(table_name)
|
||||
return [
|
||||
@@ -92,8 +104,8 @@ class DorisConnect(RDBMSDatabase):
|
||||
"""Get column fields about specified table."""
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
f"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT "
|
||||
f"from information_schema.columns "
|
||||
"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
|
||||
"COLUMN_COMMENT from information_schema.columns "
|
||||
f'where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
)
|
||||
)
|
||||
@@ -104,7 +116,8 @@ class DorisConnect(RDBMSDatabase):
|
||||
"""Get character_set."""
|
||||
return "utf-8"
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
def get_show_create_table(self, table_name) -> str:
|
||||
"""Get show create table."""
|
||||
# cur = self.get_session().execute(
|
||||
# text(
|
||||
# f"""show create table {table_name}"""
|
||||
@@ -128,6 +141,7 @@ class DorisConnect(RDBMSDatabase):
|
||||
return ""
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
"""Get table comments."""
|
||||
db_name = "database()" if not db_name else f"'{db_name}'"
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
@@ -139,10 +153,8 @@ class DorisConnect(RDBMSDatabase):
|
||||
tables = cursor.fetchall()
|
||||
return [(table[0], table[1]) for table in tables]
|
||||
|
||||
def get_database_list(self):
|
||||
return self.get_database_names()
|
||||
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
cursor = self.get_session().execute(text("SHOW DATABASES"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
@@ -160,15 +172,17 @@ class DorisConnect(RDBMSDatabase):
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
"""Get current database name."""
|
||||
return self.get_session().execute(text("select database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
"""Get table simple info."""
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') "
|
||||
f"FROM information_schema.columns "
|
||||
f"where TABLE_SCHEMA=database() "
|
||||
f"GROUP BY TABLE_NAME"
|
||||
"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') "
|
||||
"FROM information_schema.columns "
|
||||
"where TABLE_SCHEMA=database() "
|
||||
"GROUP BY TABLE_NAME"
|
||||
)
|
||||
)
|
||||
results = cursor.fetchall()
|
||||
|
@@ -1,15 +1,13 @@
|
||||
"""DuckDB connector."""
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
class DuckDbConnector(RDBMSConnector):
|
||||
"""DuckDB connector."""
|
||||
|
||||
db_type: str = "duckdb"
|
||||
db_dialect: str = "duckdb"
|
||||
@@ -17,21 +15,24 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
@classmethod
|
||||
def from_file_path(
|
||||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
) -> RDBMSConnector:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
|
||||
|
||||
def get_users(self):
|
||||
"""Get users."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
|
||||
"SELECT * FROM sqlite_master WHERE type = 'table' AND "
|
||||
"name = 'duckdb_sys_users';"
|
||||
)
|
||||
)
|
||||
users = cursor.fetchall()
|
||||
return [(user[0], user[1]) for user in users]
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
@@ -39,12 +40,14 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set of current database."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
def get_table_comments(self, db_name: str):
|
||||
"""Get table comments."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
"""
|
||||
SELECT name, sql FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
)
|
||||
@@ -55,7 +58,8 @@ class DuckDbConnect(RDBMSDatabase):
|
||||
]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
"""Get table simple info."""
|
||||
_tables_sql = """
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
|
@@ -1,14 +1,13 @@
|
||||
from typing import Any, Optional
|
||||
"""Hive Connector."""
|
||||
from typing import Any, Optional, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
|
||||
class HiveConnect(RDBMSDatabase):
|
||||
"""db type"""
|
||||
class HiveConnector(RDBMSConnector):
|
||||
"""Hive connector."""
|
||||
|
||||
db_type: str = "hive"
|
||||
"""db driver"""
|
||||
@@ -26,28 +25,26 @@ class HiveConnect(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from uri database.
|
||||
Args:
|
||||
host (str): database host.
|
||||
port (int): database port.
|
||||
user (str): database user.
|
||||
pwd (str): database password.
|
||||
db_name (str): database name.
|
||||
engine_args (Optional[dict]):other engine_args.
|
||||
"""
|
||||
) -> "HiveConnector":
|
||||
"""Create a new HiveConnector from host, port, user, pwd, db_name."""
|
||||
db_url: str = f"{cls.driver}://{host}:{str(port)}/{db_name}"
|
||||
if user and pwd:
|
||||
db_url: str = f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
db_url = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/"
|
||||
f"{db_name}"
|
||||
)
|
||||
return cast(HiveConnector, cls.from_uri(db_url, engine_args, **kwargs))
|
||||
|
||||
def table_simple_info(self):
|
||||
"""Get table simple info."""
|
||||
return []
|
||||
|
||||
def get_users(self):
|
||||
"""Get users."""
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
@@ -55,4 +52,5 @@ class HiveConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set of current database."""
|
||||
return "UTF-8"
|
||||
|
@@ -1,17 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Any, Iterable, Optional
|
||||
"""MSSQL connector."""
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
|
||||
from sqlalchemy import text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
class MSSQLConnector(RDBMSConnector):
|
||||
"""MSSQL connector."""
|
||||
|
||||
db_type: str = "mssql"
|
||||
db_dialect: str = "mssql"
|
||||
@@ -20,8 +16,10 @@ class MSSQLConnect(RDBMSDatabase):
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
|
||||
"""Get table simple info."""
|
||||
_tables_sql = """
|
||||
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE
|
||||
TABLE_TYPE='BASE TABLE'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
@@ -29,7 +27,8 @@ class MSSQLConnect(RDBMSDatabase):
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}'
|
||||
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE
|
||||
TABLE_NAME='{table_name}'
|
||||
"""
|
||||
cursor_colums = self.session.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
|
@@ -1,13 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
"""MySQL connector."""
|
||||
|
||||
from .base import RDBMSConnector
|
||||
|
||||
|
||||
class MySQLConnect(RDBMSDatabase):
|
||||
"""Connect MySQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
class MySQLConnector(RDBMSConnector):
|
||||
"""MySQL connector."""
|
||||
|
||||
db_type: str = "mysql"
|
||||
db_dialect: str = "mysql"
|
||||
|
@@ -1,13 +1,19 @@
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
"""PostgreSQL connector."""
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Tuple, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostgreSQLDatabase(RDBMSDatabase):
|
||||
class PostgreSQLConnector(RDBMSConnector):
|
||||
"""PostgreSQL connector."""
|
||||
|
||||
driver = "postgresql+psycopg2"
|
||||
db_type = "postgresql"
|
||||
db_dialect = "postgresql"
|
||||
@@ -22,34 +28,38 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> "PostgreSQLConnector":
|
||||
"""Create a new PostgreSQLConnector from host, port, user, pwd, db_name."""
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
return cast(PostgreSQLConnector, cls.from_uri(db_url, engine_args, **kwargs))
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.session.execute(
|
||||
text(
|
||||
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
"SELECT tablename FROM pg_catalog.pg_tables WHERE "
|
||||
"schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
view_results = self.session.execute(
|
||||
text(
|
||||
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
"SELECT viewname FROM pg_catalog.pg_views WHERE "
|
||||
"schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
view_results = set(row[0] for row in view_results)
|
||||
table_results = set(row[0] for row in table_results) # noqa: C401
|
||||
view_results = set(row[0] for row in view_results) # noqa: C401
|
||||
self._all_tables = table_results.union(view_results)
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""
|
||||
"""
|
||||
SELECT DISTINCT grantee, privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = CURRENT_USER;"""
|
||||
@@ -64,13 +74,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
|
||||
"SELECT datcollate AS collation FROM pg_database WHERE "
|
||||
"datname = current_database();"
|
||||
)
|
||||
)
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
except Exception as e:
|
||||
print("postgresql get collation error: ", e)
|
||||
logger.warning(f"postgresql get collation error: {str(e)}")
|
||||
return None
|
||||
|
||||
def get_users(self):
|
||||
@@ -82,7 +93,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
users = cursor.fetchall()
|
||||
return [user[0] for user in users]
|
||||
except Exception as e:
|
||||
print("postgresql get users error: ", e)
|
||||
logger.warning(f"postgresql get users error: {str(e)}")
|
||||
return []
|
||||
|
||||
def get_fields(self, table_name) -> List[Tuple]:
|
||||
@@ -90,7 +101,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
|
||||
"SELECT column_name, data_type, column_default, is_nullable, "
|
||||
"column_name as column_comment \
|
||||
FROM information_schema.columns WHERE table_name = :table_name",
|
||||
),
|
||||
{"table_name": table_name},
|
||||
@@ -103,23 +115,28 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
|
||||
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE "
|
||||
"datname = current_database();"
|
||||
)
|
||||
)
|
||||
character_set = cursor.fetchone()[0]
|
||||
return character_set
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
def get_show_create_table(self, table_name: str):
|
||||
"""Return show create table."""
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
|
||||
SELECT a.attname as column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
|
||||
FROM pg_catalog.pg_attribute a
|
||||
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
|
||||
SELECT max(a.attnum)
|
||||
FROM pg_catalog.pg_attribute a
|
||||
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
|
||||
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
|
||||
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class
|
||||
WHERE relname='{table_name}')
|
||||
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class
|
||||
WHERE relname='{table_name}')
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -133,6 +150,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
return create_table_query
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
"""Get table comments."""
|
||||
tablses = self.table_simple_info()
|
||||
comments = []
|
||||
for table in tablses:
|
||||
@@ -141,15 +159,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
comments.append((table_name, table_comment))
|
||||
return comments
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||
]
|
||||
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
@@ -158,10 +169,12 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
"""Get current database name."""
|
||||
return self.session.execute(text("SELECT current_database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
_sql = f"""
|
||||
"""Get table simple info."""
|
||||
_sql = """
|
||||
SELECT table_name, string_agg(column_name, ', ') AS schema_info
|
||||
FROM (
|
||||
SELECT c.relname AS table_name, a.attname AS column_name
|
||||
@@ -181,17 +194,18 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
def get_fields(self, table_name, schema_name="public"):
|
||||
def get_fields_wit_schema(self, table_name, schema_name="public"):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
|
||||
FROM information_schema.columns c
|
||||
LEFT JOIN pg_catalog.pg_description d
|
||||
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
|
||||
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
|
||||
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable,
|
||||
d.description FROM information_schema.columns c
|
||||
LEFT JOIN pg_catalog.pg_description d
|
||||
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid
|
||||
AND c.ordinal_position = d.objsubid
|
||||
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
|
||||
"""
|
||||
)
|
||||
)
|
||||
@@ -203,7 +217,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
|
||||
f"SELECT indexname, indexdef FROM pg_indexes WHERE "
|
||||
f"tablename = '{table_name}'"
|
||||
)
|
||||
)
|
||||
indexes = cursor.fetchall()
|
||||
|
@@ -1,6 +1,4 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""SQLite connector."""
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
@@ -8,16 +6,13 @@ from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from .base import RDBMSConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLiteConnect(RDBMSDatabase):
|
||||
"""Connect SQLite Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
class SQLiteConnector(RDBMSConnector):
|
||||
"""SQLite connector."""
|
||||
|
||||
db_type: str = "sqlite"
|
||||
db_dialect: str = "sqlite"
|
||||
@@ -25,8 +20,8 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
@classmethod
|
||||
def from_file_path(
|
||||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
) -> "SQLiteConnector":
|
||||
"""Create a new SQLiteConnector from file path."""
|
||||
_engine_args = engine_args or {}
|
||||
_engine_args["connect_args"] = {"check_same_thread": False}
|
||||
# _engine_args["echo"] = True
|
||||
@@ -52,7 +47,8 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
"""Get table show create table about specified table."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'"
|
||||
"SELECT sql FROM sqlite_master WHERE type='table' "
|
||||
f"AND name='{table_name}'"
|
||||
)
|
||||
)
|
||||
ans = cursor.fetchall()
|
||||
@@ -62,7 +58,7 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
"""Get column fields about specified table."""
|
||||
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
|
||||
fields = cursor.fetchall()
|
||||
print(fields)
|
||||
logger.info(fields)
|
||||
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
|
||||
|
||||
def get_simple_fields(self, table_name):
|
||||
@@ -70,9 +66,11 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
return self.get_fields(table_name)
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
@@ -80,12 +78,11 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set of current database."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
|
||||
return []
|
||||
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
return []
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
@@ -95,25 +92,27 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
view_results = self.session.execute(
|
||||
text("SELECT name FROM sqlite_master WHERE type='view'")
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
view_results = set(row[0] for row in view_results)
|
||||
table_results = set(row[0] for row in table_results) # noqa
|
||||
view_results = set(row[0] for row in view_results) # noqa
|
||||
self._all_tables = table_results.union(view_results)
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def _write(self, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
logger.info(f"Write[{write_sql}]")
|
||||
session = self.session
|
||||
result = session.execute(text(write_sql))
|
||||
session.commit()
|
||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
# TODO Subsequent optimization of dynamically specified database submission
|
||||
# loss target problem
|
||||
logger.info(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
"""Get table comments."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
"""
|
||||
SELECT name, sql FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
)
|
||||
@@ -124,7 +123,8 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
"""Get table simple info."""
|
||||
_tables_sql = """
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
@@ -146,10 +146,14 @@ class SQLiteConnect(RDBMSDatabase):
|
||||
return results
|
||||
|
||||
|
||||
class SQLiteTempConnect(SQLiteConnect):
|
||||
"""A temporary SQLite database connection. The database file will be deleted when the connection is closed."""
|
||||
class SQLiteTempConnector(SQLiteConnector):
|
||||
"""A temporary SQLite database connection.
|
||||
|
||||
The database file will be deleted when the connection is closed.
|
||||
"""
|
||||
|
||||
def __init__(self, engine, temp_file_path, *args, **kwargs):
|
||||
"""Construct a temporary SQLite database connection."""
|
||||
super().__init__(engine, *args, **kwargs)
|
||||
self.temp_file_path = temp_file_path
|
||||
self._is_closed = False
|
||||
@@ -157,7 +161,7 @@ class SQLiteTempConnect(SQLiteConnect):
|
||||
@classmethod
|
||||
def create_temporary_db(
|
||||
cls, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> "SQLiteTempConnect":
|
||||
) -> "SQLiteTempConnector":
|
||||
"""Create a temporary SQLite database with a temporary file.
|
||||
|
||||
Examples:
|
||||
@@ -175,7 +179,7 @@ class SQLiteTempConnect(SQLiteConnect):
|
||||
engine_args (Optional[dict]): SQLAlchemy engine arguments.
|
||||
|
||||
Returns:
|
||||
SQLiteTempConnect: A SQLiteTempConnect instance.
|
||||
SQLiteTempConnector: A SQLiteTempConnect instance.
|
||||
"""
|
||||
_engine_args = engine_args or {}
|
||||
_engine_args["connect_args"] = {"check_same_thread": False}
|
||||
@@ -219,7 +223,7 @@ class SQLiteTempConnect(SQLiteConnect):
|
||||
],
|
||||
},
|
||||
}
|
||||
with SQLiteTempConnect.create_temporary_db() as db:
|
||||
with SQLiteTempConnector.create_temporary_db() as db:
|
||||
db.create_temp_tables(tables_info)
|
||||
field_names, result = db.query_ex(db.session, "select * from test")
|
||||
assert field_names == ["id", "name", "age"]
|
||||
@@ -248,14 +252,18 @@ class SQLiteTempConnect(SQLiteConnect):
|
||||
self._sync_tables_from_db()
|
||||
|
||||
def __enter__(self):
|
||||
"""Return the connection when entering the context manager."""
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Close the connection when exiting the context manager."""
|
||||
self.close()
|
||||
|
||||
def __del__(self):
|
||||
"""Close the connection when the object is deleted."""
|
||||
self.close()
|
||||
|
||||
@classmethod
|
||||
def is_normal_type(cls) -> bool:
|
||||
"""Return whether the connector is a normal type."""
|
||||
return False
|
||||
|
@@ -1,21 +1,24 @@
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
"""StarRocks connector."""
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, cast
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import *
|
||||
from .base import RDBMSConnector
|
||||
from .dialect.starrocks.sqlalchemy import * # noqa
|
||||
|
||||
|
||||
class StarRocksConnect(RDBMSDatabase):
|
||||
class StarRocksConnector(RDBMSConnector):
|
||||
"""StarRocks connector."""
|
||||
|
||||
driver = "starrocks"
|
||||
db_type = "starrocks"
|
||||
db_dialect = "starrocks"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
cls: Type["StarRocksConnector"],
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
@@ -23,27 +26,31 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> "StarRocksConnector":
|
||||
"""Create a new StarRocksConnector from host, port, user, pwd, db_name."""
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
return cast(StarRocksConnector, cls.from_uri(db_url, engine_args, **kwargs))
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
db_name = self.get_current_db_name()
|
||||
table_results = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
|
||||
"SELECT TABLE_NAME FROM information_schema.tables where "
|
||||
f'TABLE_SCHEMA="{db_name}"'
|
||||
)
|
||||
)
|
||||
# view_results = self.session.execute(text(f'SELECT TABLE_NAME from information_schema.materialized_views where TABLE_SCHEMA="{db_name}"'))
|
||||
table_results = set(row[0] for row in table_results)
|
||||
# view_results = self.session.execute(text(f'SELECT TABLE_NAME from
|
||||
# information_schema.materialized_views where TABLE_SCHEMA="{db_name}"'))
|
||||
table_results = set(row[0] for row in table_results) # noqa: C401
|
||||
# view_results = set(row[0] for row in view_results)
|
||||
self._all_tables = table_results
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
@@ -56,7 +63,7 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
return grants_list
|
||||
|
||||
def _get_current_version(self):
|
||||
"""Get database current version"""
|
||||
"""Get database current version."""
|
||||
return int(self.session.execute(text("select current_version()")).scalar())
|
||||
|
||||
def get_collation(self):
|
||||
@@ -75,7 +82,9 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
db_name = f'"{db_name}"'
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f'select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.columns where TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}'
|
||||
"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
|
||||
"COLUMN_COMMENT from information_schema.columns where "
|
||||
f'TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}'
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
@@ -83,10 +92,10 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
|
||||
return "utf-8"
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
def get_show_create_table(self, table_name: str):
|
||||
"""Get show create table."""
|
||||
# cur = self.session.execute(
|
||||
# text(
|
||||
# f"""show create table {table_name}"""
|
||||
@@ -99,7 +108,8 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
# 这里是要表描述, 返回建表语句会导致token过长而失败
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
"SELECT TABLE_COMMENT FROM information_schema.tables where "
|
||||
f'TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
)
|
||||
)
|
||||
table = cur.fetchone()
|
||||
@@ -109,20 +119,20 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
return ""
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
"""Get table comments."""
|
||||
if not db_name:
|
||||
db_name = self.get_current_db_name()
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
|
||||
"SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables "
|
||||
f'where TABLE_SCHEMA="{db_name}"'
|
||||
)
|
||||
)
|
||||
tables = cur.fetchall()
|
||||
return [(table[0], table[1]) for table in tables]
|
||||
|
||||
def get_database_list(self):
|
||||
return self.get_database_names()
|
||||
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SHOW DATABASES;"))
|
||||
results = cursor.fetchall()
|
||||
@@ -133,11 +143,14 @@ class StarRocksConnect(RDBMSDatabase):
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
"""Get current database name."""
|
||||
return self.session.execute(text("select database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
_sql = f"""
|
||||
SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") FROM information_schema.columns where TABLE_SCHEMA=database()
|
||||
"""Get table simple info."""
|
||||
_sql = """
|
||||
SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");")
|
||||
FROM information_schema.columns where TABLE_SCHEMA=database()
|
||||
GROUP BY TABLE_NAME
|
||||
"""
|
||||
cursor = self.session.execute(text(_sql))
|
||||
|
@@ -0,0 +1 @@
|
||||
"""Module for RDBMS dialects."""
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/python3
|
||||
"""StarRocks dialect for SQLAlchemy."""
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/python3
|
||||
"""SQLAlchemy dialect for StarRocks."""
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
@@ -1,3 +1,5 @@
|
||||
"""SQLAlchemy data types for StarRocks."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
@@ -10,50 +12,71 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TINYINT(Integer): # pylint: disable=no-init
|
||||
"""StarRocks TINYINT type."""
|
||||
|
||||
__visit_name__ = "TINYINT"
|
||||
|
||||
|
||||
class LARGEINT(Integer): # pylint: disable=no-init
|
||||
"""StarRocks LARGEINT type."""
|
||||
|
||||
__visit_name__ = "LARGEINT"
|
||||
|
||||
|
||||
class DOUBLE(Float): # pylint: disable=no-init
|
||||
"""StarRocks DOUBLE type."""
|
||||
|
||||
__visit_name__ = "DOUBLE"
|
||||
|
||||
|
||||
class HLL(Numeric): # pylint: disable=no-init
|
||||
"""StarRocks HLL type."""
|
||||
|
||||
__visit_name__ = "HLL"
|
||||
|
||||
|
||||
class BITMAP(Numeric): # pylint: disable=no-init
|
||||
"""StarRocks BITMAP type."""
|
||||
|
||||
__visit_name__ = "BITMAP"
|
||||
|
||||
|
||||
class PERCENTILE(Numeric): # pylint: disable=no-init
|
||||
"""StarRocks PERCENTILE type."""
|
||||
|
||||
__visit_name__ = "PERCENTILE"
|
||||
|
||||
|
||||
class ARRAY(TypeEngine): # pylint: disable=no-init
|
||||
"""StarRocks ARRAY type."""
|
||||
|
||||
__visit_name__ = "ARRAY"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[List[Any]]]:
|
||||
def python_type(self) -> Optional[Type[List[Any]]]: # type: ignore
|
||||
"""Return the Python type for this SQL type."""
|
||||
return list
|
||||
|
||||
|
||||
class MAP(TypeEngine): # pylint: disable=no-init
|
||||
"""StarRocks MAP type."""
|
||||
|
||||
__visit_name__ = "MAP"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
|
||||
def python_type(self) -> Optional[Type[Dict[Any, Any]]]: # type: ignore
|
||||
"""Return the Python type for this SQL type."""
|
||||
return dict
|
||||
|
||||
|
||||
class STRUCT(TypeEngine): # pylint: disable=no-init
|
||||
"""StarRocks STRUCT type."""
|
||||
|
||||
__visit_name__ = "STRUCT"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[Any]]:
|
||||
def python_type(self) -> Optional[Type[Any]]: # type: ignore
|
||||
"""Return the Python type for this SQL type."""
|
||||
return None
|
||||
|
||||
|
||||
@@ -90,6 +113,7 @@ _type_map = {
|
||||
|
||||
|
||||
def parse_sqltype(type_str: str) -> TypeEngine:
|
||||
"""Parse a SQL type string into a SQLAlchemy type object."""
|
||||
type_str = type_str.strip().lower()
|
||||
match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str)
|
||||
if not match:
|
||||
|
@@ -1,4 +1,4 @@
|
||||
#! /usr/bin/python3
|
||||
"""StarRocks dialect for SQLAlchemy."""
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from sqlalchemy import exc, log, text
|
||||
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
|
||||
@@ -25,7 +25,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@log.class_logger
|
||||
class StarRocksDialect(MySQLDialect_pymysql):
|
||||
class StarRocksDialect(MySQLDialect_pymysql): # type: ignore
|
||||
"""StarRocks dialect for SQLAlchemy."""
|
||||
|
||||
# Caching
|
||||
# Warnings are generated by SQLAlchmey if this flag is not explicitly set
|
||||
# and tests are needed before being enabled
|
||||
@@ -34,9 +36,11 @@ class StarRocksDialect(MySQLDialect_pymysql):
|
||||
name = "starrocks"
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
"""Create a new StarRocks dialect."""
|
||||
super(StarRocksDialect, self).__init__(*args, **kw)
|
||||
|
||||
def has_table(self, connection, table_name, schema=None, **kw):
|
||||
def has_table(self, connection, table_name, schema: Optional[str] = None, **kw):
|
||||
"""Return True if the given table is present in the database."""
|
||||
self._ensure_has_table_connection(connection)
|
||||
|
||||
if schema is None:
|
||||
@@ -53,15 +57,13 @@ class StarRocksDialect(MySQLDialect_pymysql):
|
||||
return res.first() is not None
|
||||
|
||||
def get_schema_names(self, connection, **kw):
|
||||
"""Return a list of schema names available in the database."""
|
||||
rp = connection.exec_driver_sql("SHOW schemas")
|
||||
return [r[0] for r in rp]
|
||||
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
def get_table_names(self, connection, schema: Optional[str] = None, **kw):
|
||||
"""Return a Unicode SHOW TABLES from a given schema."""
|
||||
if schema is not None:
|
||||
current_schema = schema
|
||||
else:
|
||||
current_schema = self.default_schema_name
|
||||
current_schema: str = cast(str, schema or self.default_schema_name)
|
||||
|
||||
charset = self._connection_charset
|
||||
|
||||
@@ -76,13 +78,15 @@ class StarRocksDialect(MySQLDialect_pymysql):
|
||||
if row[1] == "BASE TABLE"
|
||||
]
|
||||
|
||||
def get_view_names(self, connection, schema=None, **kw):
|
||||
def get_view_names(self, connection, schema: Optional[str] = None, **kw):
|
||||
"""Return a Unicode SHOW TABLES from a given schema."""
|
||||
if schema is None:
|
||||
schema = self.default_schema_name
|
||||
current_schema = cast(str, schema)
|
||||
charset = self._connection_charset
|
||||
rp = connection.exec_driver_sql(
|
||||
"SHOW FULL TABLES FROM %s"
|
||||
% self.identifier_preparer.quote_identifier(schema)
|
||||
% self.identifier_preparer.quote_identifier(current_schema)
|
||||
)
|
||||
return [
|
||||
row[0]
|
||||
@@ -90,9 +94,14 @@ class StarRocksDialect(MySQLDialect_pymysql):
|
||||
if row[1] in ("VIEW", "SYSTEM VIEW")
|
||||
]
|
||||
|
||||
def get_columns(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_columns( # type: ignore
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> List[Dict[str, Any]]: # type: ignore
|
||||
"""Return information about columns in `table_name`."""
|
||||
if not self.has_table(connection, table_name, schema):
|
||||
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
|
||||
schema = schema or self._get_default_schema_name(connection)
|
||||
@@ -114,60 +123,100 @@ class StarRocksDialect(MySQLDialect_pymysql):
|
||||
columns.append(column)
|
||||
return columns
|
||||
|
||||
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
|
||||
def get_pk_constraint(
|
||||
self, connection, table_name, schema: Optional[str] = None, **kw
|
||||
):
|
||||
"""Return information about the primary key constraint."""
|
||||
return { # type: ignore # pep-655 not supported
|
||||
"name": None,
|
||||
"constrained_columns": [],
|
||||
}
|
||||
|
||||
def get_unique_constraints(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
def get_unique_constraints( # type: ignore
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return information about unique constraints."""
|
||||
return []
|
||||
|
||||
def get_check_constraints(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
def get_check_constraints( # type: ignore
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return information about check constraints."""
|
||||
return []
|
||||
|
||||
def get_foreign_keys(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
def get_foreign_keys( # type: ignore
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return information about foreign keys."""
|
||||
return []
|
||||
|
||||
def get_primary_keys(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> List[str]:
|
||||
"""Return the primary key columns of the given table."""
|
||||
pk = self.get_pk_constraint(connection, table_name, schema)
|
||||
return pk.get("constrained_columns") # type: ignore
|
||||
|
||||
def get_indexes(self, connection, table_name, schema=None, **kw):
|
||||
def get_indexes(self, connection, table_name, schema: Optional[str] = None, **kw):
|
||||
"""Get table indexes about specified table."""
|
||||
return []
|
||||
|
||||
def has_sequence(
|
||||
self, connection: Connection, sequence_name: str, schema: str = None, **kw
|
||||
self,
|
||||
connection: Connection,
|
||||
sequence_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> bool:
|
||||
"""Return True if the given sequence is present in the database."""
|
||||
return False
|
||||
|
||||
def get_sequence_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
self, connection: Connection, schema: Optional[str] = None, **kw
|
||||
) -> List[str]:
|
||||
"""Return a list of sequence names."""
|
||||
return []
|
||||
|
||||
def get_temp_view_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
self, connection: Connection, schema: Optional[str] = None, **kw
|
||||
) -> List[str]:
|
||||
"""Return a list of temporary view names."""
|
||||
return []
|
||||
|
||||
def get_temp_table_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
self, connection: Connection, schema: Optional[str] = None, **kw
|
||||
) -> List[str]:
|
||||
"""Return a list of temporary table names."""
|
||||
return []
|
||||
|
||||
def get_table_options(self, connection, table_name, schema=None, **kw):
|
||||
def get_table_options(
|
||||
self, connection, table_name, schema: Optional[str] = None, **kw
|
||||
):
|
||||
"""Return a dictionary of options specified when the table was created."""
|
||||
return {}
|
||||
|
||||
def get_table_comment(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
def get_table_comment( # type: ignore
|
||||
self,
|
||||
connection: Connection,
|
||||
table_name: str,
|
||||
schema: Optional[str] = None,
|
||||
**kw,
|
||||
) -> Dict[str, Any]:
|
||||
"""Return the comment for a table."""
|
||||
return dict(text=None)
|
||||
|
@@ -6,14 +6,14 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
|
||||
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
temp_db_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
temp_db_file.close()
|
||||
conn = DuckDbConnect.from_file_path(temp_db_file.name + "duckdb.db")
|
||||
conn = DuckDbConnector.from_file_path(temp_db_file.name + "duckdb.db")
|
||||
yield conn
|
||||
|
||||
|
||||
|
@@ -6,14 +6,14 @@ import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
temp_db_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
temp_db_file.close()
|
||||
conn = SQLiteConnect.from_file_path(temp_db_file.name)
|
||||
conn = SQLiteConnector.from_file_path(temp_db_file.name)
|
||||
yield conn
|
||||
try:
|
||||
# TODO: Failed on windows
|
||||
@@ -43,7 +43,7 @@ def test_run_sql(db):
|
||||
|
||||
|
||||
def test_run_no_throw(db):
|
||||
assert db.run_no_throw("this is a error sql").startswith("Error:")
|
||||
assert db.run_no_throw("this is a error sql") == []
|
||||
|
||||
|
||||
def test_get_indexes(db):
|
||||
@@ -122,10 +122,6 @@ def test_get_table_comments(db):
|
||||
]
|
||||
|
||||
|
||||
def test_get_database_list(db):
|
||||
db.get_database_list() == []
|
||||
|
||||
|
||||
def test_get_database_names(db):
|
||||
db.get_database_names() == []
|
||||
|
||||
@@ -134,11 +130,11 @@ def test_db_dir_exist_dir():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
new_dir = os.path.join(temp_dir, "new_dir")
|
||||
file_path = os.path.join(new_dir, "sqlite.db")
|
||||
db = SQLiteConnect.from_file_path(file_path)
|
||||
db = SQLiteConnector.from_file_path(file_path)
|
||||
assert os.path.exists(new_dir) == True
|
||||
assert list(db.get_table_names()) == []
|
||||
with tempfile.TemporaryDirectory() as existing_dir:
|
||||
file_path = os.path.join(existing_dir, "sqlite.db")
|
||||
db = SQLiteConnect.from_file_path(file_path)
|
||||
db = SQLiteConnector.from_file_path(file_path)
|
||||
assert os.path.exists(existing_dir) == True
|
||||
assert list(db.get_table_names()) == []
|
||||
|
Reference in New Issue
Block a user