refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -0,0 +1 @@
"""RDBMS Connector Module."""

View File

@@ -1,86 +0,0 @@
import logging
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from dbgpt._private.config import Config
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.storage.schema import DBType
logger = logging.getLogger(__name__)
CFG = Config()
class BaseDao:
    """Base data-access object with a lazily created engine and session factory.

    Nothing touches the database at construction time. The engine (and the
    connector it came from, when there is one) is resolved on the first read
    of ``db_engine``; the session factory is built on the first read of
    ``Session``.
    """

    def __init__(
        self, orm_base=None, database: str = None, create_not_exist_table: bool = False
    ) -> None:
        """Remember the DAO settings without opening any connection.

        Args:
            orm_base: Declarative base handed to ``_get_db_engine``; its
                metadata is used there to create missing tables.
            database: Database name handed to ``_get_db_engine``.
            create_not_exist_table: When the configured database is a file
                database and this flag is true, tables that do not exist yet
                are created automatically once the engine is resolved.
        """
        self._orm_base, self._database = orm_base, database
        self._create_not_exist_table = create_not_exist_table
        # All three are filled in lazily by the properties below.
        self._db_engine = self._session = self._connection = None

    @property
    def db_engine(self):
        """Return the engine, resolving and caching it on first access."""
        if self._db_engine:
            return self._db_engine
        resolved = _get_db_engine(
            self._orm_base, self._database, self._create_not_exist_table
        )
        # The helper returns an (engine, connection) pair; keep both.
        self._db_engine, self._connection = resolved
        return self._db_engine

    @property
    def Session(self):
        """Return the session factory bound to ``db_engine``, built once."""
        factory = self._session
        if not factory:
            factory = self._session = sessionmaker(bind=self.db_engine)
        return factory
def _get_db_engine(
    orm_base=None, database: str = None, create_not_exist_table: bool = False
):
    """Resolve the SQLAlchemy engine (and connector, if any) for ``BaseDao``.

    The database type is read from ``CFG.LOCAL_DB_TYPE``. When it is unset or
    MySQL, an engine is created directly from the ``CFG.LOCAL_DB_*`` settings
    and no connector is involved. Otherwise the connector is obtained from
    ``CFG.LOCAL_DB_MANAGE`` and its engine is reused.

    Args:
        orm_base: Declarative base whose metadata is used to create missing
            tables (file databases only).
        database: Database name. For file databases it is replaced by the
            name parsed from ``CFG.LOCAL_DB_PATH``.
        create_not_exist_table: Create tables that do not exist yet when the
            database is a file database and ``orm_base`` is given.

    Returns:
        A ``(db_engine, connection)`` tuple; ``connection`` is ``None`` on the
        MySQL/default path.

    Raises:
        Exception: ``CFG.LOCAL_DB_MANAGE`` is not initialized.
        ValueError: A file database is configured without ``LOCAL_DB_PATH``,
            or the manager returned something that is not an
            ``RDBMSDatabase``.
    """
    # Local import: keeps this fix independent of the module's import block.
    from urllib.parse import quote_plus

    db_engine = None
    connection: RDBMSDatabase = None
    db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)

    if db_type is None or db_type == DBType.Mysql:
        # Default database. Credentials are URL-escaped so characters such as
        # "@", ":" or "/" in the user name or password cannot corrupt the URL.
        user = quote_plus(str(CFG.LOCAL_DB_USER))
        password = quote_plus(str(CFG.LOCAL_DB_PASSWORD))
        db_engine = create_engine(
            f"mysql+pymysql://{user}:{password}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
            echo=True,
        )
    else:
        db_manager = CFG.LOCAL_DB_MANAGE
        if not db_manager:
            raise Exception(
                "LOCAL_DB_MANAGE is not initialized, please check the system configuration"
            )
        if db_type.is_file_db():
            db_path = CFG.LOCAL_DB_PATH
            if db_path is None or db_path == "":
                raise ValueError(
                    "You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file"
                )
            # For file databases the effective name comes from the file path.
            _, database = db_manager._parse_file_db_info(db_type.value(), db_path)
            logger.info(
                f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}"
            )
        logger.info(f"Get DAO database connection with database name {database}")
        connection = db_manager.get_connect(database)
        if not isinstance(connection, RDBMSDatabase):
            raise ValueError(
                "Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration"
            )
        db_engine = connection._engine

    # db_type is None on the unconfigured default path, so it must be checked
    # before is_file_db() is called on it.
    is_file_db = db_type is not None and db_type.is_file_db()
    if is_file_db and orm_base is not None and create_not_exist_table:
        logger.info("Current database is file database, create not exist table")
        orm_base.metadata.create_all(db_engine)
    return db_engine, connection

View File

@@ -1,6 +1,9 @@
"""Base class for RDBMS connectors."""
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional, Tuple
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
@@ -13,11 +16,10 @@ from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.schema import CreateTable
from dbgpt._private.config import Config
from dbgpt.datasource.base import BaseConnect
from dbgpt.datasource.base import BaseConnector
from dbgpt.storage.schema import DBType
CFG = Config()
logger = logging.getLogger(__name__)
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
@@ -27,11 +29,9 @@ def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
)
class RDBMSDatabase(BaseConnect):
class RDBMSConnector(BaseConnector):
"""SQLAlchemy wrapper around a database."""
db_type: str = None
def __init__(
self,
engine,
@@ -41,10 +41,11 @@ class RDBMSDatabase(BaseConnect):
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
indexes_in_table_info: bool = False,
custom_table_info: Optional[dict] = None,
custom_table_info: Optional[Dict[str, str]] = None,
view_support: bool = False,
):
"""Create engine from database URI.
Args:
- engine: Engine sqlalchemy.engine
- schema: Optional[str].
@@ -61,28 +62,27 @@ class RDBMSDatabase(BaseConnect):
if include_tables and ignore_tables:
raise ValueError("Cannot specify both include_tables and ignore_tables")
if not custom_table_info:
custom_table_info = {}
self._inspector = inspect(engine)
session_factory = sessionmaker(bind=engine)
Session_Manages = scoped_session(session_factory)
self._db_sessions = Session_Manages
self.session = self.get_session()
self._all_tables = set()
self.view_support = False
self._usable_tables = set()
self._include_tables = set()
self._ignore_tables = set()
self._custom_table_info = set()
self._indexes_in_table_info = set()
self._usable_tables = set()
self._usable_tables = set()
self._sample_rows_in_table_info = set()
self.view_support = view_support
self._usable_tables: Set[str] = set()
self._include_tables: Set[str] = set()
self._ignore_tables: Set[str] = set()
self._custom_table_info = custom_table_info
self._sample_rows_in_table_info = sample_rows_in_table_info
self._indexes_in_table_info = indexes_in_table_info
self._metadata = MetaData()
self._metadata = metadata or MetaData()
self._metadata.reflect(bind=self._engine)
self._all_tables = self._sync_tables_from_db()
self._all_tables: Set[str] = cast(Set[str], self._sync_tables_from_db())
@classmethod
def from_uri_db(
@@ -94,8 +94,9 @@ class RDBMSDatabase(BaseConnect):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> RDBMSConnector:
"""Construct a SQLAlchemy engine from uri database.
Args:
host (str): database host.
port (int): database port.
@@ -112,7 +113,7 @@ class RDBMSDatabase(BaseConnect):
@classmethod
def from_uri(
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
) -> RDBMSConnector:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
return cls(create_engine(database_uri, **_engine_args), **kwargs)
@@ -123,7 +124,7 @@ class RDBMSDatabase(BaseConnect):
return self._engine.dialect.name
def _sync_tables_from_db(self) -> Iterable[str]:
"""Read table information from database"""
"""Read table information from database."""
# TODO Use a background thread to refresh periodically
# SQL will raise error with schema
@@ -153,16 +154,25 @@ class RDBMSDatabase(BaseConnect):
return self.get_usable_table_names()
def get_session(self):
"""Get session."""
session = self._db_sessions()
return session
def get_current_db_name(self) -> str:
"""Get current database name.
Returns:
str: database name
"""
return self.session.execute(text("SELECT DATABASE()")).scalar()
def table_simple_info(self):
"""Return table simple info."""
_sql = f"""
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
select concat(table_name, "(" , group_concat(column_name), ")")
as schema_info from information_schema.COLUMNS where
table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
"""
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
@@ -222,12 +232,16 @@ class RDBMSDatabase(BaseConnect):
return final_str
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns.
"""Get columns about specified table.
Args:
table_name (str): table name
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
return self._inspector.get_columns(table_name)
@@ -280,13 +294,14 @@ class RDBMSDatabase(BaseConnect):
Args:
write_sql (str): SQL write command to run
"""
print(f"Write[{write_sql}]")
logger.info(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = self.session.execute(text(write_sql))
self.session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
# TODO Subsequent optimization of dynamically specified database submission
# loss target problem
self.session.execute(text(f"use `{db_cache}`"))
print(f"SQL[{write_sql}], result:{result.rowcount}")
logger.info(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def _query(self, query: str, fetch: str = "all"):
@@ -296,9 +311,9 @@ class RDBMSDatabase(BaseConnect):
query (str): SQL query to run
fetch (str): fetch type
"""
result = []
result: List[Any] = []
print(f"Query[{query}]")
logger.info(f"Query[{query}]")
if not query:
return result
cursor = self.session.execute(text(query))
@@ -314,20 +329,28 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
return result
def query_table_schema(self, table_name):
def query_table_schema(self, table_name: str):
"""Query table schema.
Args:
table_name (str): table name
"""
sql = f"select * from {table_name} limit 1"
return self._query(sql)
def query_ex(self, query, fetch: str = "all"):
"""
only for query
def query_ex(self, query: str, fetch: str = "all"):
"""Execute a SQL command and return the results.
Only for query command.
Args:
session:
query:
fetch:
query (str): SQL query to run
fetch (str): fetch type
Returns:
List: result list
"""
print(f"Query[{query}]")
logger.info(f"Query[{query}]")
if not query:
return [], None
cursor = self.session.execute(text(query))
@@ -338,7 +361,7 @@ class RDBMSDatabase(BaseConnect):
result = cursor.fetchone() # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
field_names = list(cursor.keys())
result = list(result)
return field_names, result
@@ -346,7 +369,7 @@ class RDBMSDatabase(BaseConnect):
def run(self, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
logger.info("SQL:" + command)
if not command or len(command) < 0:
return []
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
@@ -356,11 +379,13 @@ class RDBMSDatabase(BaseConnect):
else:
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
logger.info(f"write result query:{select_sql}")
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
logger.info(
"DDL execution determines whether to enable through configuration "
)
cursor = self.session.execute(text(command))
self.session.commit()
if cursor.returns_rows:
@@ -368,7 +393,7 @@ class RDBMSDatabase(BaseConnect):
field_names = tuple(i[0:] for i in cursor.keys())
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
logger.info("DDL Result:" + str(result))
if not result:
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
@@ -377,6 +402,7 @@ class RDBMSDatabase(BaseConnect):
return self.get_simple_fields(table_name)
def run_to_df(self, command: str, fetch: str = "all"):
"""Execute sql command and return result as dataframe."""
import pandas as pd
# Pandas has too much dependence and the import time is too long
@@ -398,44 +424,45 @@ class RDBMSDatabase(BaseConnect):
return self.run(command, fetch)
except SQLAlchemyError as e:
"""Format the error message"""
return f"Error: {e}"
logger.warning(f"Run SQL command failed: {e}")
return []
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
def convert_sql_write_to_select(self, write_sql: str) -> str:
"""Convert SQL write command to a SELECT command.
def convert_sql_write_to_select(self, write_sql):
"""
SQL classification processing
author:xiangh8
Examples:
.. code-block:: python
write_sql = "insert into test(id) values (1)"
select_sql = convert_sql_write_to_select(write_sql)
print(select_sql)
# SELECT * FROM test WHERE id=1
Args:
sql:
write_sql (str): SQL write command
Returns:
str: SELECT command corresponding to the write command
"""
# 将SQL命令转换为小写并按空格拆分
# Convert the SQL command to lowercase and split by space
parts = write_sql.lower().split()
# 获取命令类型(insert, delete, update
# Get the command type (insert, delete, update)
cmd_type = parts[0]
# 根据命令类型进行处理
# Handle according to command type
if cmd_type == "insert":
match = re.match(
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
r"insert\s+into\s+(\w+)\s*\(([^)]+)\)\s*values\s*\(([^)]+)\)",
write_sql.lower(),
)
if match:
# Get the table name, columns, and values
table_name, columns, values = match.groups()
# 将字段列表和值列表分割为单独的字段和值
columns = columns.split(",")
values = values.split(",")
# 构造 WHERE 子句
# Build the WHERE clause
where_clause = " AND ".join(
[
f"{col.strip()}={val.strip()}"
@@ -443,21 +470,23 @@ class RDBMSDatabase(BaseConnect):
]
)
return f"SELECT * FROM {table_name} WHERE {where_clause}"
else:
raise ValueError(f"Unsupported SQL command: {write_sql}")
elif cmd_type == "delete":
table_name = parts[2] # delete from <table_name> ...
# 返回一个select语句它选择该表的所有数据
# Return a SELECT statement that selects all data from the table
return f"SELECT * FROM {table_name} "
elif cmd_type == "update":
table_name = parts[1]
set_idx = parts.index("set")
where_idx = parts.index("where")
# 截取 `set` 子句中的字段名
# Get the field name in the `set` clause
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
# 截取 `where` 之后的条件语句
# Get the condition statement after the `where`
where_clause = " ".join(parts[where_idx + 1 :])
# 返回一个select语句它选择更新的数据
# Return a SELECT statement that selects the updated data
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
else:
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
@@ -473,7 +502,9 @@ class RDBMSDatabase(BaseConnect):
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
logger.info(
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
)
return parsed, ttype, sql_type, table_name
def _extract_table_name_from_ddl(self, parsed):
@@ -485,8 +516,10 @@ class RDBMSDatabase(BaseConnect):
def get_indexes(self, table_name: str) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name:(str) table name
Returns:
List[Dict]:eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
@@ -504,9 +537,9 @@ class RDBMSDatabase(BaseConnect):
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
table_name
)
"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
"COLUMN_COMMENT from information_schema.COLUMNS where "
f"table_name='{table_name}'".format(table_name)
)
)
fields = cursor.fetchall()
@@ -516,40 +549,41 @@ class RDBMSDatabase(BaseConnect):
"""Get column fields about specified table."""
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_charset(self):
def get_charset(self) -> str:
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@character_set_database"))
character_set = cursor.fetchone()[0]
cursor = session.execute(text("SELECT @@character_set_database"))
character_set = cursor.fetchone()[0] # type: ignore
return character_set
def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT @@collation_database"))
cursor = session.execute(text("SELECT @@collation_database"))
collation = cursor.fetchone()[0]
return collation
def get_grants(self):
"""Get grant info."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW GRANTS"))
cursor = session.execute(text("SHOW GRANTS"))
grants = cursor.fetchall()
return grants
def get_users(self):
"""Get user info."""
try:
cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user"))
cursor = self.session.execute(text("SELECT user, host FROM mysql.user"))
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
except Exception as e:
except Exception:
return []
def get_table_comments(self, db_name: str):
"""Return table comments."""
cursor = self.session.execute(
text(
f"""SELECT table_name, table_comment FROM information_schema.tables
f"""SELECT table_name, table_comment FROM information_schema.tables
WHERE table_schema = '{db_name}'""".format(
db_name
)
@@ -570,10 +604,11 @@ class RDBMSDatabase(BaseConnect):
"""
return self._inspector.get_table_comment(table_name)
def get_column_comments(self, db_name, table_name):
def get_column_comments(self, db_name: str, table_name: str):
"""Return column comments."""
cursor = self.session.execute(
text(
f"""SELECT column_name, column_comment FROM information_schema.columns
f"""SELECT column_name, column_comment FROM information_schema.columns
WHERE table_schema = '{db_name}' and table_name = '{table_name}'
""".format(
db_name, table_name
@@ -585,17 +620,12 @@ class RDBMSDatabase(BaseConnect):
(column_comment[0], column_comment[1]) for column_comment in column_comments
]
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
]
def get_database_names(self) -> List[str]:
"""Return a list of database names available in the database.
def get_database_names(self):
Returns:
List[str]: database list
"""
session = self._db_sessions()
cursor = session.execute(text(" show databases;"))
results = cursor.fetchall()

View File

@@ -1,18 +1,20 @@
"""Clickhouse connector."""
import logging
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple
import sqlparse
from sqlalchemy import MetaData, text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.storage.schema import DBType
from .base import RDBMSConnector
class ClickhouseConnect(RDBMSDatabase):
"""Connect Clickhouse Database fetch MetaData
Args:
Usage:
"""
logger = logging.getLogger(__name__)
class ClickhouseConnector(RDBMSConnector):
"""Clickhouse connector."""
"""db type"""
db_type: str = "clickhouse"
@@ -24,6 +26,7 @@ class ClickhouseConnect(RDBMSDatabase):
client: Any = None
def __init__(self, client, **kwargs):
"""Create a new ClickhouseConnector from client."""
self.client = client
self._all_tables = set()
@@ -49,7 +52,8 @@ class ClickhouseConnect(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "ClickhouseConnector":
"""Create a new ClickhouseConnector from host, port, user, pwd, db_name."""
import clickhouse_connect
from clickhouse_connect.driver import httputil
@@ -70,11 +74,6 @@ class ClickhouseConnect(RDBMSDatabase):
cls.client = client
return cls(client, **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
pass
def get_table_names(self):
"""Get all table names."""
session = self.client
@@ -85,6 +84,7 @@ class ClickhouseConnect(RDBMSDatabase):
def get_indexes(self, table_name: str) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name (str): table name
Returns:
@@ -93,10 +93,11 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
SELECT name AS table, primary_key, from system.tables where
database ='{self.client.database}' and table = '{table_name}'
"""
with session.query_row_block_stream(_query_sql) as stream:
indexes = [block for block in stream]
indexes = [block for block in stream] # noqa
return [
{"name": "primary_key", "column_names": column_names.split(",")}
for table, column_names in indexes[0]
@@ -104,6 +105,7 @@ class ClickhouseConnect(RDBMSDatabase):
@property
def table_info(self) -> str:
"""Get table info."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
@@ -117,7 +119,7 @@ class ClickhouseConnect(RDBMSDatabase):
demonstrated in the paper.
"""
# TODO:
pass
return ""
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
@@ -133,11 +135,14 @@ class ClickhouseConnect(RDBMSDatabase):
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns.
Args:
table_name (str): str
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
fields = self.get_fields(table_name)
return [
@@ -150,18 +155,21 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
SELECT name, type, default_expression, is_in_primary_key, comment
from system.columns where table='{table_name}'
""".format(
table_name
)
with session.query_row_block_stream(_query_sql) as stream:
fields = [block for block in stream]
fields = [block for block in stream] # noqa
return fields
def get_users(self):
"""Get user info."""
return []
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
@@ -169,9 +177,11 @@ class ClickhouseConnect(RDBMSDatabase):
return "UTF-8"
def get_charset(self):
"""Get character_set."""
return "UTF-8"
def get_database_list(self):
def get_database_names(self):
"""Get database names."""
session = self.client
with session.command("SHOW DATABASES") as stream:
@@ -184,12 +194,10 @@ class ClickhouseConnect(RDBMSDatabase):
]
return databases
def get_database_names(self):
return self.get_database_list()
def run(self, command: str, fetch: str = "all") -> List:
"""Execute sql command."""
# TODO need to be implemented
print("SQL:" + command)
logger.info("SQL:" + command)
if not command or len(command) < 0:
return []
_, ttype, sql_type, table_name = self.__sql_parse(command)
@@ -199,10 +207,12 @@ class ClickhouseConnect(RDBMSDatabase):
else:
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
logger.info(f"write result query:{select_sql}")
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
logger.info(
"DDL execution determines whether to enable through configuration "
)
cursor = self.client.command(command)
@@ -212,7 +222,7 @@ class ClickhouseConnect(RDBMSDatabase):
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
logger.info("DDL Result:" + str(result))
if not result:
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
@@ -225,13 +235,16 @@ class ClickhouseConnect(RDBMSDatabase):
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_current_db_name(self):
"""Get current database name."""
return self.client.database
def get_table_comments(self, db_name: str):
"""Get table comments."""
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
SELECT table, comment FROM system.tables WHERE database = '{db_name}'
""".format(
db_name
)
@@ -241,6 +254,7 @@ class ClickhouseConnect(RDBMSDatabase):
def get_table_comment(self, table_name: str) -> Dict:
"""Get table comment.
Args:
table_name (str): table name
Returns:
@@ -249,7 +263,9 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
SELECT table, comment FROM system.tables WHERE
database = '{self.client.database}'and table = '{table_name}'
""".format(
self.client.database
)
@@ -258,9 +274,11 @@ class ClickhouseConnect(RDBMSDatabase):
return [{"text": comment} for table_name, comment in table_comments][0]
def get_column_comments(self, db_name, table_name):
"""Get column comments."""
session = self.client
_query_sql = f"""
select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
select name column, comment from system.columns where database='{db_name}'
and table='{table_name}'
""".format(
db_name, table_name
)
@@ -270,10 +288,13 @@ class ClickhouseConnect(RDBMSDatabase):
return column_comments
def table_simple_info(self):
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
"""Get table simple info."""
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray
# instead; and quotes need to be escaped
_sql = f"""
SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
SELECT concat(TABLE_NAME, '(', arrayStringConcat(
groupArray(column_name), '-'), ')') AS schema_info
FROM INFORMATION_SCHEMA.COLUMNS
WHERE table_schema = '{self.get_current_db_name()}'
GROUP BY TABLE_NAME
@@ -282,18 +303,18 @@ class ClickhouseConnect(RDBMSDatabase):
return [row[0] for block in stream for row in block]
def _write(self, write_sql: str):
"""write data
"""Execute write sql.
Args:
write_sql (str): sql string
"""
# TODO need to be implemented
print(f"Write[{write_sql}]")
logger.info(f"Write[{write_sql}]")
result = self.client.command(write_sql)
print(f"SQL[{write_sql}], result:{result.written_rows}")
logger.info(f"SQL[{write_sql}], result:{result.written_rows}")
def _query(self, query: str, fetch: str = "all"):
"""Query data from clickhouse
"""Query data from clickhouse.
Args:
query (str): sql string
@@ -306,7 +327,7 @@ class ClickhouseConnect(RDBMSDatabase):
_type_: List<Result>
"""
# TODO need to be implemented
print(f"Query[{query}]")
logger.info(f"Query[{query}]")
if not query:
return []
@@ -334,11 +355,13 @@ class ClickhouseConnect(RDBMSDatabase):
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
logger.info(
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
)
return parsed, ttype, sql_type, table_name
def _sync_tables_from_db(self) -> Iterable[str]:
"""Read table information from database"""
"""Read table information from database."""
# TODO Use a background thread to refresh periodically
# SQL will raise error with schema

View File

@@ -1,13 +1,16 @@
from typing import Any, Dict, Iterable, List, Optional, Tuple
"""Doris connector."""
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
class DorisConnect(RDBMSDatabase):
class DorisConnector(RDBMSConnector):
"""Doris connector."""
driver = "doris"
db_type = "doris"
db_dialect = "doris"
@@ -22,24 +25,27 @@ class DorisConnect(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "DorisConnector":
"""Create a new DorisConnector from host, port, user, pwd, db_name."""
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)
return cast(DorisConnector, cls.from_uri(db_url, engine_args, **kwargs))
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.get_session().execute(
text(
f"SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=database()"
"SELECT TABLE_NAME FROM information_schema.tables where "
"TABLE_SCHEMA=database()"
)
)
table_results = set(row[0] for row in table_results)
table_results = set(row[0] for row in table_results) # noqa: C401
self._all_tables = table_results
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
"""Get grants."""
cursor = self.get_session().execute(text("SHOW GRANTS"))
grants = cursor.fetchall()
if len(grants) == 0:
@@ -51,14 +57,17 @@ class DorisConnect(RDBMSDatabase):
return grants_list
def _get_current_version(self):
"""Get database current version"""
"""Get database current version."""
return int(
self.get_session().execute(text("select current_version()")).scalar()
)
def get_collation(self):
"""Get collation.
ref: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-reference/Show-Statements/SHOW-COLLATION/
ref `SHOW COLLATION <https://doris.apache.org/zh-CN/docs/dev/sql-manual/
sql-reference/Show-Statements/SHOW-COLLATION/>`_
"""
cursor = self.get_session().execute(text("SHOW COLLATION"))
results = cursor.fetchall()
@@ -70,11 +79,14 @@ class DorisConnect(RDBMSDatabase):
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns.
Args:
table_name (str): str
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
fields = self.get_fields(table_name)
return [
@@ -92,8 +104,8 @@ class DorisConnect(RDBMSDatabase):
"""Get column fields about specified table."""
cursor = self.get_session().execute(
text(
f"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT "
f"from information_schema.columns "
"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
"COLUMN_COMMENT from information_schema.columns "
f'where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
)
)
@@ -104,7 +116,8 @@ class DorisConnect(RDBMSDatabase):
"""Get character_set."""
return "utf-8"
def get_show_create_table(self, table_name):
def get_show_create_table(self, table_name) -> str:
"""Get show create table."""
# cur = self.get_session().execute(
# text(
# f"""show create table {table_name}"""
@@ -128,6 +141,7 @@ class DorisConnect(RDBMSDatabase):
return ""
def get_table_comments(self, db_name=None):
"""Get table comments."""
db_name = "database()" if not db_name else f"'{db_name}'"
cursor = self.get_session().execute(
text(
@@ -139,10 +153,8 @@ class DorisConnect(RDBMSDatabase):
tables = cursor.fetchall()
return [(table[0], table[1]) for table in tables]
def get_database_list(self):
return self.get_database_names()
def get_database_names(self):
"""Get database names."""
cursor = self.get_session().execute(text("SHOW DATABASES"))
results = cursor.fetchall()
return [
@@ -160,15 +172,17 @@ class DorisConnect(RDBMSDatabase):
]
def get_current_db_name(self) -> str:
"""Get current database name."""
return self.get_session().execute(text("select database()")).scalar()
def table_simple_info(self):
"""Get table simple info."""
cursor = self.get_session().execute(
text(
f"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') "
f"FROM information_schema.columns "
f"where TABLE_SCHEMA=database() "
f"GROUP BY TABLE_NAME"
"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') "
"FROM information_schema.columns "
"where TABLE_SCHEMA=database() "
"GROUP BY TABLE_NAME"
)
)
results = cursor.fetchall()

View File

@@ -1,15 +1,13 @@
"""DuckDB connector."""
from typing import Any, Iterable, Optional
from sqlalchemy import create_engine, text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
class DuckDbConnect(RDBMSDatabase):
"""Connect Duckdb Database fetch MetaData
Args:
Usage:
"""
class DuckDbConnector(RDBMSConnector):
"""DuckDB connector."""
db_type: str = "duckdb"
db_dialect: str = "duckdb"
@@ -17,21 +15,24 @@ class DuckDbConnect(RDBMSDatabase):
@classmethod
def from_file_path(
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
) -> RDBMSConnector:
"""Construct a SQLAlchemy engine from URI."""
_engine_args = engine_args or {}
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
def get_users(self):
"""Get users."""
cursor = self.session.execute(
text(
f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
"SELECT * FROM sqlite_master WHERE type = 'table' AND "
"name = 'duckdb_sys_users';"
)
)
users = cursor.fetchall()
return [(user[0], user[1]) for user in users]
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
@@ -39,12 +40,14 @@ class DuckDbConnect(RDBMSDatabase):
return "UTF-8"
def get_charset(self):
"""Get character_set of current database."""
return "UTF-8"
def get_table_comments(self, db_name):
def get_table_comments(self, db_name: str):
"""Get table comments."""
cursor = self.session.execute(
text(
f"""
"""
SELECT name, sql FROM sqlite_master WHERE type='table'
"""
)
@@ -55,7 +58,8 @@ class DuckDbConnect(RDBMSDatabase):
]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
"""Get table simple info."""
_tables_sql = """
SELECT name FROM sqlite_master WHERE type='table'
"""
cursor = self.session.execute(text(_tables_sql))

View File

@@ -1,14 +1,13 @@
from typing import Any, Optional
"""Hive Connector."""
from typing import Any, Optional, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
class HiveConnect(RDBMSDatabase):
"""db type"""
class HiveConnector(RDBMSConnector):
"""Hive connector."""
db_type: str = "hive"
"""db driver"""
@@ -26,28 +25,26 @@ class HiveConnect(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from uri database.
Args:
host (str): database host.
port (int): database port.
user (str): database user.
pwd (str): database password.
db_name (str): database name.
engine_args (Optional[dict]):other engine_args.
"""
) -> "HiveConnector":
"""Create a new HiveConnector from host, port, user, pwd, db_name."""
db_url: str = f"{cls.driver}://{host}:{str(port)}/{db_name}"
if user and pwd:
db_url: str = f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
return cls.from_uri(db_url, engine_args, **kwargs)
db_url = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/"
f"{db_name}"
)
return cast(HiveConnector, cls.from_uri(db_url, engine_args, **kwargs))
def table_simple_info(self):
"""Get table simple info."""
return []
def get_users(self):
"""Get users."""
return []
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
@@ -55,4 +52,5 @@ class HiveConnect(RDBMSDatabase):
return "UTF-8"
def get_charset(self):
"""Get character_set of current database."""
return "UTF-8"

View File

@@ -1,17 +1,13 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Any, Iterable, Optional
"""MSSQL connector."""
from typing import Iterable
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
class MSSQLConnect(RDBMSDatabase):
"""Connect MSSQL Database fetch MetaData
Args:
Usage:
"""
class MSSQLConnector(RDBMSConnector):
"""MSSQL connector."""
db_type: str = "mssql"
db_dialect: str = "mssql"
@@ -20,8 +16,10 @@ class MSSQLConnect(RDBMSDatabase):
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
"""Get table simple info."""
_tables_sql = """
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE
TABLE_TYPE='BASE TABLE'
"""
cursor = self.session.execute(text(_tables_sql))
tables_results = cursor.fetchall()
@@ -29,7 +27,8 @@ class MSSQLConnect(RDBMSDatabase):
for row in tables_results:
table_name = row[0]
_sql = f"""
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}'
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE
TABLE_NAME='{table_name}'
"""
cursor_colums = self.session.execute(text(_sql))
colum_results = cursor_colums.fetchall()

View File

@@ -1,13 +1,10 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from dbgpt.datasource.rdbms.base import RDBMSDatabase
"""MySQL connector."""
from .base import RDBMSConnector
class MySQLConnect(RDBMSDatabase):
"""Connect MySQL Database fetch MetaData
Args:
Usage:
"""
class MySQLConnector(RDBMSConnector):
"""MySQL connector."""
db_type: str = "mysql"
db_dialect: str = "mysql"

View File

@@ -1,13 +1,19 @@
from typing import Any, Iterable, List, Optional, Tuple
"""PostgreSQL connector."""
import logging
from typing import Any, Iterable, List, Optional, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
logger = logging.getLogger(__name__)
class PostgreSQLDatabase(RDBMSDatabase):
class PostgreSQLConnector(RDBMSConnector):
"""PostgreSQL connector."""
driver = "postgresql+psycopg2"
db_type = "postgresql"
db_dialect = "postgresql"
@@ -22,34 +28,38 @@ class PostgreSQLDatabase(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "PostgreSQLConnector":
"""Create a new PostgreSQLConnector from host, port, user, pwd, db_name."""
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)
return cast(PostgreSQLConnector, cls.from_uri(db_url, engine_args, **kwargs))
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
text(
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
"SELECT tablename FROM pg_catalog.pg_tables WHERE "
"schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
)
)
view_results = self.session.execute(
text(
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
"SELECT viewname FROM pg_catalog.pg_views WHERE "
"schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
)
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
table_results = set(row[0] for row in table_results) # noqa: C401
view_results = set(row[0] for row in view_results) # noqa: C401
self._all_tables = table_results.union(view_results)
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
"""Get grants."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
"""
SELECT DISTINCT grantee, privilege_type
FROM information_schema.role_table_grants
WHERE grantee = CURRENT_USER;"""
@@ -64,13 +74,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
"SELECT datcollate AS collation FROM pg_database WHERE "
"datname = current_database();"
)
)
collation = cursor.fetchone()[0]
return collation
except Exception as e:
print("postgresql get collation error: ", e)
logger.warning(f"postgresql get collation error: {str(e)}")
return None
def get_users(self):
@@ -82,7 +93,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
users = cursor.fetchall()
return [user[0] for user in users]
except Exception as e:
print("postgresql get users error: ", e)
logger.warning(f"postgresql get users error: {str(e)}")
return []
def get_fields(self, table_name) -> List[Tuple]:
@@ -90,7 +101,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
"SELECT column_name, data_type, column_default, is_nullable, "
"column_name as column_comment \
FROM information_schema.columns WHERE table_name = :table_name",
),
{"table_name": table_name},
@@ -103,23 +115,28 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE "
"datname = current_database();"
)
)
character_set = cursor.fetchone()[0]
return character_set
def get_show_create_table(self, table_name):
def get_show_create_table(self, table_name: str):
"""Return show create table."""
cur = self.session.execute(
text(
f"""
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
SELECT a.attname as column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
FROM pg_catalog.pg_attribute a
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
SELECT max(a.attnum)
FROM pg_catalog.pg_attribute a
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class
WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class
WHERE relname='{table_name}')
"""
)
)
@@ -133,6 +150,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
return create_table_query
def get_table_comments(self, db_name=None):
"""Get table comments."""
tablses = self.table_simple_info()
comments = []
for table in tablses:
@@ -141,15 +159,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
comments.append((table_name, table_comment))
return comments
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
return [
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
]
def get_database_names(self):
"""Get database names."""
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
@@ -158,10 +169,12 @@ class PostgreSQLDatabase(RDBMSDatabase):
]
def get_current_db_name(self) -> str:
    """Return the current database name.

    The name is whatever ``SELECT current_database()`` yields on
    ``self.session``.
    """
    result = self.session.execute(text("SELECT current_database()"))
    return result.scalar()
def table_simple_info(self):
_sql = f"""
"""Get table simple info."""
_sql = """
SELECT table_name, string_agg(column_name, ', ') AS schema_info
FROM (
SELECT c.relname AS table_name, a.attname AS column_name
@@ -181,17 +194,18 @@ class PostgreSQLDatabase(RDBMSDatabase):
results = cursor.fetchall()
return results
def get_fields(self, table_name, schema_name="public"):
def get_fields_wit_schema(self, table_name, schema_name="public"):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable,
d.description FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid
AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
"""
)
)
@@ -203,7 +217,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
f"SELECT indexname, indexdef FROM pg_indexes WHERE "
f"tablename = '{table_name}'"
)
)
indexes = cursor.fetchall()

View File

@@ -1,6 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""SQLite connector."""
import logging
import os
import tempfile
@@ -8,16 +6,13 @@ from typing import Any, Iterable, List, Optional, Tuple
from sqlalchemy import create_engine, text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
logger = logging.getLogger(__name__)
class SQLiteConnect(RDBMSDatabase):
"""Connect SQLite Database fetch MetaData
Args:
Usage:
"""
class SQLiteConnector(RDBMSConnector):
"""SQLite connector."""
db_type: str = "sqlite"
db_dialect: str = "sqlite"
@@ -25,8 +20,8 @@ class SQLiteConnect(RDBMSDatabase):
@classmethod
def from_file_path(
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
) -> RDBMSDatabase:
"""Construct a SQLAlchemy engine from URI."""
) -> "SQLiteConnector":
"""Create a new SQLiteConnector from file path."""
_engine_args = engine_args or {}
_engine_args["connect_args"] = {"check_same_thread": False}
# _engine_args["echo"] = True
@@ -52,7 +47,8 @@ class SQLiteConnect(RDBMSDatabase):
"""Get table show create table about specified table."""
cursor = self.session.execute(
text(
f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'"
"SELECT sql FROM sqlite_master WHERE type='table' "
f"AND name='{table_name}'"
)
)
ans = cursor.fetchall()
@@ -62,7 +58,7 @@ class SQLiteConnect(RDBMSDatabase):
"""Get column fields about specified table."""
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
fields = cursor.fetchall()
print(fields)
logger.info(fields)
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
def get_simple_fields(self, table_name):
@@ -70,9 +66,11 @@ class SQLiteConnect(RDBMSDatabase):
return self.get_fields(table_name)
def get_users(self):
"""Get user info."""
return []
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
@@ -80,12 +78,11 @@ class SQLiteConnect(RDBMSDatabase):
return "UTF-8"
def get_charset(self):
"""Get character_set of current database."""
return "UTF-8"
def get_database_list(self):
return []
def get_database_names(self):
"""Get database names."""
return []
def _sync_tables_from_db(self) -> Iterable[str]:
@@ -95,25 +92,27 @@ class SQLiteConnect(RDBMSDatabase):
view_results = self.session.execute(
text("SELECT name FROM sqlite_master WHERE type='view'")
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
table_results = set(row[0] for row in table_results) # noqa
view_results = set(row[0] for row in view_results) # noqa
self._all_tables = table_results.union(view_results)
self._metadata.reflect(bind=self._engine)
return self._all_tables
def _write(self, write_sql):
print(f"Write[{write_sql}]")
logger.info(f"Write[{write_sql}]")
session = self.session
result = session.execute(text(write_sql))
session.commit()
# TODO Subsequent optimization of dynamically specified database submission loss target problem
print(f"SQL[{write_sql}], result:{result.rowcount}")
# TODO Subsequent optimization of dynamically specified database submission
# loss target problem
logger.info(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def get_table_comments(self, db_name=None):
"""Get table comments."""
cursor = self.session.execute(
text(
f"""
"""
SELECT name, sql FROM sqlite_master WHERE type='table'
"""
)
@@ -124,7 +123,8 @@ class SQLiteConnect(RDBMSDatabase):
]
def table_simple_info(self) -> Iterable[str]:
_tables_sql = f"""
"""Get table simple info."""
_tables_sql = """
SELECT name FROM sqlite_master WHERE type='table'
"""
cursor = self.session.execute(text(_tables_sql))
@@ -146,10 +146,14 @@ class SQLiteConnect(RDBMSDatabase):
return results
class SQLiteTempConnect(SQLiteConnect):
"""A temporary SQLite database connection. The database file will be deleted when the connection is closed."""
class SQLiteTempConnector(SQLiteConnector):
"""A temporary SQLite database connection.
The database file will be deleted when the connection is closed.
"""
def __init__(self, engine, temp_file_path, *args, **kwargs):
"""Construct a temporary SQLite database connection."""
super().__init__(engine, *args, **kwargs)
self.temp_file_path = temp_file_path
self._is_closed = False
@@ -157,7 +161,7 @@ class SQLiteTempConnect(SQLiteConnect):
@classmethod
def create_temporary_db(
cls, engine_args: Optional[dict] = None, **kwargs: Any
) -> "SQLiteTempConnect":
) -> "SQLiteTempConnector":
"""Create a temporary SQLite database with a temporary file.
Examples:
@@ -175,7 +179,7 @@ class SQLiteTempConnect(SQLiteConnect):
engine_args (Optional[dict]): SQLAlchemy engine arguments.
Returns:
SQLiteTempConnect: A SQLiteTempConnect instance.
SQLiteTempConnector: A SQLiteTempConnector instance.
"""
_engine_args = engine_args or {}
_engine_args["connect_args"] = {"check_same_thread": False}
@@ -219,7 +223,7 @@ class SQLiteTempConnect(SQLiteConnect):
],
},
}
with SQLiteTempConnect.create_temporary_db() as db:
with SQLiteTempConnector.create_temporary_db() as db:
db.create_temp_tables(tables_info)
field_names, result = db.query_ex(db.session, "select * from test")
assert field_names == ["id", "name", "age"]
@@ -248,14 +252,18 @@ class SQLiteTempConnect(SQLiteConnect):
self._sync_tables_from_db()
def __enter__(self):
"""Return the connection when entering the context manager."""
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Close the connection when exiting the context manager."""
self.close()
def __del__(self):
"""Close the connection when the object is deleted."""
self.close()
@classmethod
def is_normal_type(cls) -> bool:
"""Return whether the connector is a normal type."""
return False

View File

@@ -1,21 +1,24 @@
from typing import Any, Iterable, List, Optional, Tuple
"""StarRocks connector."""
from typing import Any, Iterable, List, Optional, Tuple, Type, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import *
from .base import RDBMSConnector
from .dialect.starrocks.sqlalchemy import * # noqa
class StarRocksConnect(RDBMSDatabase):
class StarRocksConnector(RDBMSConnector):
"""StarRocks connector."""
driver = "starrocks"
db_type = "starrocks"
db_dialect = "starrocks"
@classmethod
def from_uri_db(
cls,
cls: Type["StarRocksConnector"],
host: str,
port: int,
user: str,
@@ -23,27 +26,31 @@ class StarRocksConnect(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "StarRocksConnector":
"""Create a new StarRocksConnector from host, port, user, pwd, db_name."""
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)
return cast(StarRocksConnector, cls.from_uri(db_url, engine_args, **kwargs))
def _sync_tables_from_db(self) -> Iterable[str]:
db_name = self.get_current_db_name()
table_results = self.session.execute(
text(
f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
"SELECT TABLE_NAME FROM information_schema.tables where "
f'TABLE_SCHEMA="{db_name}"'
)
)
# view_results = self.session.execute(text(f'SELECT TABLE_NAME from information_schema.materialized_views where TABLE_SCHEMA="{db_name}"'))
table_results = set(row[0] for row in table_results)
# view_results = self.session.execute(text(f'SELECT TABLE_NAME from
# information_schema.materialized_views where TABLE_SCHEMA="{db_name}"'))
table_results = set(row[0] for row in table_results) # noqa: C401
# view_results = set(row[0] for row in view_results)
self._all_tables = table_results
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
"""Get grants."""
session = self._db_sessions()
cursor = session.execute(text("SHOW GRANTS"))
grants = cursor.fetchall()
@@ -56,7 +63,7 @@ class StarRocksConnect(RDBMSDatabase):
return grants_list
def _get_current_version(self):
"""Get database current version"""
"""Get database current version."""
return int(self.session.execute(text("select current_version()")).scalar())
def get_collation(self):
@@ -75,7 +82,9 @@ class StarRocksConnect(RDBMSDatabase):
db_name = f'"{db_name}"'
cursor = session.execute(
text(
f'select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.columns where TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}'
"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, "
"COLUMN_COMMENT from information_schema.columns where "
f'TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}'
)
)
fields = cursor.fetchall()
@@ -83,10 +92,10 @@ class StarRocksConnect(RDBMSDatabase):
def get_charset(self):
"""Get character_set."""
return "utf-8"
def get_show_create_table(self, table_name):
def get_show_create_table(self, table_name: str):
"""Get show create table."""
# cur = self.session.execute(
# text(
# f"""show create table {table_name}"""
@@ -99,7 +108,8 @@ class StarRocksConnect(RDBMSDatabase):
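# Only the table comment is returned here; returning the full CREATE TABLE statement would make the token count too large and fail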
cur = self.session.execute(
text(
f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
"SELECT TABLE_COMMENT FROM information_schema.tables where "
f'TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
)
)
table = cur.fetchone()
@@ -109,20 +119,20 @@ class StarRocksConnect(RDBMSDatabase):
return ""
def get_table_comments(self, db_name=None):
"""Get table comments."""
if not db_name:
db_name = self.get_current_db_name()
cur = self.session.execute(
text(
f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
"SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables "
f'where TABLE_SCHEMA="{db_name}"'
)
)
tables = cur.fetchall()
return [(table[0], table[1]) for table in tables]
def get_database_list(self):
return self.get_database_names()
def get_database_names(self):
"""Get database names."""
session = self._db_sessions()
cursor = session.execute(text("SHOW DATABASES;"))
results = cursor.fetchall()
@@ -133,11 +143,14 @@ class StarRocksConnect(RDBMSDatabase):
]
def get_current_db_name(self) -> str:
    """Return the current database name.

    The name is whatever ``select database()`` yields on ``self.session``.
    """
    result = self.session.execute(text("select database()"))
    return result.scalar()
def table_simple_info(self):
_sql = f"""
SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") FROM information_schema.columns where TABLE_SCHEMA=database()
"""Get table simple info."""
_sql = """
SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");")
FROM information_schema.columns where TABLE_SCHEMA=database()
GROUP BY TABLE_NAME
"""
cursor = self.session.execute(text(_sql))

View File

@@ -0,0 +1 @@
"""Module for RDBMS dialects."""

View File

@@ -1,4 +1,4 @@
#! /usr/bin/python3
"""StarRocks dialect for SQLAlchemy."""
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,4 +1,4 @@
#! /usr/bin/python3
"""SQLAlchemy dialect for StarRocks."""
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");

View File

@@ -1,3 +1,5 @@
"""SQLAlchemy data types for StarRocks."""
import logging
import re
from typing import Any, Dict, List, Optional, Type
@@ -10,50 +12,71 @@ logger = logging.getLogger(__name__)
class TINYINT(Integer): # pylint: disable=no-init
"""StarRocks TINYINT type."""
__visit_name__ = "TINYINT"
class LARGEINT(Integer): # pylint: disable=no-init
"""StarRocks LARGEINT type."""
__visit_name__ = "LARGEINT"
class DOUBLE(Float): # pylint: disable=no-init
"""StarRocks DOUBLE type."""
__visit_name__ = "DOUBLE"
class HLL(Numeric): # pylint: disable=no-init
"""StarRocks HLL type."""
__visit_name__ = "HLL"
class BITMAP(Numeric): # pylint: disable=no-init
"""StarRocks BITMAP type."""
__visit_name__ = "BITMAP"
class PERCENTILE(Numeric): # pylint: disable=no-init
"""StarRocks PERCENTILE type."""
__visit_name__ = "PERCENTILE"
class ARRAY(TypeEngine): # pylint: disable=no-init
"""StarRocks ARRAY type."""
__visit_name__ = "ARRAY"
@property
def python_type(self) -> Optional[Type[List[Any]]]:
def python_type(self) -> Optional[Type[List[Any]]]: # type: ignore
"""Return the Python type for this SQL type."""
return list
class MAP(TypeEngine): # pylint: disable=no-init
"""StarRocks MAP type."""
__visit_name__ = "MAP"
@property
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
def python_type(self) -> Optional[Type[Dict[Any, Any]]]: # type: ignore
"""Return the Python type for this SQL type."""
return dict
class STRUCT(TypeEngine): # pylint: disable=no-init
"""StarRocks STRUCT type."""
__visit_name__ = "STRUCT"
@property
def python_type(self) -> Optional[Type[Any]]:
def python_type(self) -> Optional[Type[Any]]: # type: ignore
"""Return the Python type for this SQL type."""
return None
@@ -90,6 +113,7 @@ _type_map = {
def parse_sqltype(type_str: str) -> TypeEngine:
"""Parse a SQL type string into a SQLAlchemy type object."""
type_str = type_str.strip().lower()
match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str)
if not match:

View File

@@ -1,4 +1,4 @@
#! /usr/bin/python3
"""StarRocks dialect for SQLAlchemy."""
# Copyright 2021-present StarRocks, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, cast
from sqlalchemy import exc, log, text
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
@@ -25,7 +25,9 @@ logger = logging.getLogger(__name__)
@log.class_logger
class StarRocksDialect(MySQLDialect_pymysql):
class StarRocksDialect(MySQLDialect_pymysql): # type: ignore
"""StarRocks dialect for SQLAlchemy."""
# Caching
# Warnings are generated by SQLAlchemy if this flag is not explicitly set
# and tests are needed before being enabled
@@ -34,9 +36,11 @@ class StarRocksDialect(MySQLDialect_pymysql):
name = "starrocks"
def __init__(self, *args, **kw):
"""Create a new StarRocks dialect."""
super(StarRocksDialect, self).__init__(*args, **kw)
def has_table(self, connection, table_name, schema=None, **kw):
def has_table(self, connection, table_name, schema: Optional[str] = None, **kw):
"""Return True if the given table is present in the database."""
self._ensure_has_table_connection(connection)
if schema is None:
@@ -53,15 +57,13 @@ class StarRocksDialect(MySQLDialect_pymysql):
return res.first() is not None
def get_schema_names(self, connection, **kw):
    """Return the schema names reported by ``SHOW schemas``.

    Args:
        connection: Object exposing ``exec_driver_sql``; the statement is
            executed on it directly.
        **kw: Accepted for interface compatibility; not used here.

    Returns:
        A list holding the first column of every returned row.
    """
    rows = connection.exec_driver_sql("SHOW schemas")
    schema_names = []
    for row in rows:
        # Only the first column of each row is kept.
        schema_names.append(row[0])
    return schema_names
def get_table_names(self, connection, schema=None, **kw):
def get_table_names(self, connection, schema: Optional[str] = None, **kw):
"""Return a Unicode SHOW TABLES from a given schema."""
if schema is not None:
current_schema = schema
else:
current_schema = self.default_schema_name
current_schema: str = cast(str, schema or self.default_schema_name)
charset = self._connection_charset
@@ -76,13 +78,15 @@ class StarRocksDialect(MySQLDialect_pymysql):
if row[1] == "BASE TABLE"
]
def get_view_names(self, connection, schema=None, **kw):
def get_view_names(self, connection, schema: Optional[str] = None, **kw):
"""Return a Unicode SHOW TABLES from a given schema."""
if schema is None:
schema = self.default_schema_name
current_schema = cast(str, schema)
charset = self._connection_charset
rp = connection.exec_driver_sql(
"SHOW FULL TABLES FROM %s"
% self.identifier_preparer.quote_identifier(schema)
% self.identifier_preparer.quote_identifier(current_schema)
)
return [
row[0]
@@ -90,9 +94,14 @@ class StarRocksDialect(MySQLDialect_pymysql):
if row[1] in ("VIEW", "SYSTEM VIEW")
]
def get_columns(
self, connection: Connection, table_name: str, schema: str = None, **kw
) -> List[Dict[str, Any]]:
def get_columns( # type: ignore
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw,
) -> List[Dict[str, Any]]: # type: ignore
"""Return information about columns in `table_name`."""
if not self.has_table(connection, table_name, schema):
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
schema = schema or self._get_default_schema_name(connection)
@@ -114,60 +123,100 @@ class StarRocksDialect(MySQLDialect_pymysql):
columns.append(column)
return columns
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
def get_pk_constraint(
self, connection, table_name, schema: Optional[str] = None, **kw
):
"""Return information about the primary key constraint."""
return { # type: ignore # pep-655 not supported
"name": None,
"constrained_columns": [],
}
def get_unique_constraints(
self, connection: Connection, table_name: str, schema: str = None, **kw
def get_unique_constraints( # type: ignore
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw,
) -> List[Dict[str, Any]]:
"""Return information about unique constraints."""
return []
def get_check_constraints(
self, connection: Connection, table_name: str, schema: str = None, **kw
def get_check_constraints( # type: ignore
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw,
) -> List[Dict[str, Any]]:
"""Return information about check constraints."""
return []
def get_foreign_keys(
self, connection: Connection, table_name: str, schema: str = None, **kw
def get_foreign_keys( # type: ignore
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw,
) -> List[Dict[str, Any]]:
"""Return information about foreign keys."""
return []
def get_primary_keys(
self, connection: Connection, table_name: str, schema: str = None, **kw
self,
connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw,
) -> List[str]:
"""Return the primary key columns of the given table."""
pk = self.get_pk_constraint(connection, table_name, schema)
return pk.get("constrained_columns") # type: ignore
def get_indexes(self, connection, table_name, schema=None, **kw):
def get_indexes(self, connection, table_name, schema: Optional[str] = None, **kw):
"""Get table indexes about specified table."""
return []
def has_sequence(
self, connection: Connection, sequence_name: str, schema: str = None, **kw
self,
connection: Connection,
sequence_name: str,
schema: Optional[str] = None,
**kw,
) -> bool:
"""Return True if the given sequence is present in the database."""
return False
def get_sequence_names(
self, connection: Connection, schema: str = None, **kw
self, connection: Connection, schema: Optional[str] = None, **kw
) -> List[str]:
"""Return a list of sequence names."""
return []
def get_temp_view_names(
    self, connection: Connection, schema: Optional[str] = None, **kw
) -> List[str]:
    """Return a list of temporary view names.

    This implementation performs no lookup: it reports no temporary views.

    Args:
        connection: Database connection; not used here.
        schema: Optional schema name; not used here.
        **kw: Extra keyword arguments; ignored.

    Returns:
        Always an empty list.
    """
    return []
def get_temp_table_names(
    self, connection: Connection, schema: Optional[str] = None, **kw
) -> List[str]:
    """Return a list of temporary table names.

    This implementation performs no lookup: it reports no temporary tables.

    Args:
        connection: Database connection; not used here.
        schema: Optional schema name; not used here.
        **kw: Extra keyword arguments; ignored.

    Returns:
        Always an empty list.
    """
    return []
def get_table_options(
    self, connection, table_name, schema: Optional[str] = None, **kw
):
    """Return a dictionary of options specified when the table was created.

    This implementation performs no lookup: it reports no options for
    any table.

    Args:
        connection: Database connection; not used here.
        table_name: Name of the table; not used here.
        schema: Optional schema name; not used here.
        **kw: Extra keyword arguments; ignored.

    Returns:
        Always an empty dict.
    """
    return {}
def get_table_comment(  # type: ignore
    self,
    connection: Connection,
    table_name: str,
    schema: Optional[str] = None,
    **kw,
) -> Dict[str, Any]:
    """Return the comment for a table.

    This implementation performs no lookup: every table is reported as
    having no comment.

    Args:
        connection: Database connection; not used here.
        table_name: Name of the table; not used here.
        schema: Optional schema name; not used here.
        **kw: Extra keyword arguments; ignored.

    Returns:
        Always ``{"text": None}``.
    """
    return {"text": None}

View File

@@ -6,14 +6,14 @@ import tempfile
import pytest
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnector
@pytest.fixture
def db():
    """Yield a ``DuckDbConnector`` backed by a scratch database file.

    A named temporary file is created and immediately closed so that its
    name can serve as a unique path prefix; the connector is opened on that
    name with ``"duckdb.db"`` appended.
    """
    temp_db_file = tempfile.NamedTemporaryFile(delete=False)
    temp_db_file.close()
    # NOTE(review): no removal of the temporary files is visible in this
    # hunk -- confirm whether cleanup happens elsewhere.
    conn = DuckDbConnector.from_file_path(temp_db_file.name + "duckdb.db")
    yield conn

View File

@@ -6,14 +6,14 @@ import tempfile
import pytest
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnector
@pytest.fixture
def db():
temp_db_file = tempfile.NamedTemporaryFile(delete=False)
temp_db_file.close()
conn = SQLiteConnect.from_file_path(temp_db_file.name)
conn = SQLiteConnector.from_file_path(temp_db_file.name)
yield conn
try:
# TODO: Failed on windows
@@ -43,7 +43,7 @@ def test_run_sql(db):
def test_run_no_throw(db):
    """Invalid SQL passed to ``run_no_throw`` must yield an empty list."""
    # The result is compared against [] rather than an "Error:"-prefixed
    # string: a list has no ``startswith`` method.
    assert db.run_no_throw("this is a error sql") == []
def test_get_indexes(db):
@@ -122,10 +122,6 @@ def test_get_table_comments(db):
]
def test_get_database_list(db):
    """``get_database_list`` is expected to return an empty list."""
    # NOTE(review): the enclosing hunk header (-122,10 +122,6) suggests this
    # test may have been dropped by the refactor -- confirm that
    # ``get_database_list`` still exists on the connector.
    # The comparison was previously a bare expression whose result was
    # discarded, so the test could not fail on a wrong value.
    assert db.get_database_list() == []
def test_get_database_names(db):
    """``get_database_names`` is expected to return an empty list."""
    # The comparison was previously a bare expression whose result was
    # discarded, so the test could not fail on a wrong value.
    assert db.get_database_names() == []
@@ -134,11 +130,11 @@ def test_db_dir_exist_dir():
# Case 1: the database file lives in a sub-directory that does not exist yet;
# after building the connector that directory must be present on disk.
with tempfile.TemporaryDirectory() as temp_dir:
    new_dir = os.path.join(temp_dir, "new_dir")
    file_path = os.path.join(new_dir, "sqlite.db")
    db = SQLiteConnector.from_file_path(file_path)
    assert os.path.exists(new_dir)
    assert list(db.get_table_names()) == []
# Case 2: the parent directory already exists; building the connector must
# leave it in place, and the fresh database must contain no tables.
with tempfile.TemporaryDirectory() as existing_dir:
    file_path = os.path.join(existing_dir, "sqlite.db")
    db = SQLiteConnector.from_file_path(file_path)
    assert os.path.exists(existing_dir)
    assert list(db.get_table_names()) == []