refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -1,18 +1,20 @@
"""Clickhouse connector."""
import logging
import re
from typing import Any, Dict, Iterable, List, Optional, Tuple
import sqlparse
from sqlalchemy import MetaData, text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.storage.schema import DBType
from .base import RDBMSConnector
class ClickhouseConnect(RDBMSDatabase):
"""Connect Clickhouse Database fetch MetaData
Args:
Usage:
"""
logger = logging.getLogger(__name__)
class ClickhouseConnector(RDBMSConnector):
"""Clickhouse connector."""
"""db type"""
db_type: str = "clickhouse"
@@ -24,6 +26,7 @@ class ClickhouseConnect(RDBMSDatabase):
client: Any = None
def __init__(self, client, **kwargs):
"""Create a new ClickhouseConnector from client."""
self.client = client
self._all_tables = set()
@@ -49,7 +52,8 @@ class ClickhouseConnect(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "ClickhouseConnector":
"""Create a new ClickhouseConnector from host, port, user, pwd, db_name."""
import clickhouse_connect
from clickhouse_connect.driver import httputil
@@ -70,11 +74,6 @@ class ClickhouseConnect(RDBMSDatabase):
cls.client = client
return cls(client, **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
pass
def get_table_names(self):
"""Get all table names."""
session = self.client
@@ -85,6 +84,7 @@ class ClickhouseConnect(RDBMSDatabase):
def get_indexes(self, table_name: str) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name (str): table name
Returns:
@@ -93,10 +93,11 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
SELECT name AS table, primary_key, from system.tables where
database ='{self.client.database}' and table = '{table_name}'
"""
with session.query_row_block_stream(_query_sql) as stream:
indexes = [block for block in stream]
indexes = [block for block in stream] # noqa
return [
{"name": "primary_key", "column_names": column_names.split(",")}
for table, column_names in indexes[0]
@@ -104,6 +105,7 @@ class ClickhouseConnect(RDBMSDatabase):
@property
def table_info(self) -> str:
"""Get table info."""
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
@@ -117,7 +119,7 @@ class ClickhouseConnect(RDBMSDatabase):
demonstrated in the paper.
"""
# TODO:
pass
return ""
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
@@ -133,11 +135,14 @@ class ClickhouseConnect(RDBMSDatabase):
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns.
Args:
table_name (str): str
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
fields = self.get_fields(table_name)
return [
@@ -150,18 +155,21 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
SELECT name, type, default_expression, is_in_primary_key, comment
from system.columns where table='{table_name}'
""".format(
table_name
)
with session.query_row_block_stream(_query_sql) as stream:
fields = [block for block in stream]
fields = [block for block in stream] # noqa
return fields
def get_users(self):
"""Get user info."""
return []
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
@@ -169,9 +177,11 @@ class ClickhouseConnect(RDBMSDatabase):
return "UTF-8"
def get_charset(self):
"""Get character_set."""
return "UTF-8"
def get_database_list(self):
def get_database_names(self):
"""Get database names."""
session = self.client
with session.command("SHOW DATABASES") as stream:
@@ -184,12 +194,10 @@ class ClickhouseConnect(RDBMSDatabase):
]
return databases
def get_database_names(self):
return self.get_database_list()
def run(self, command: str, fetch: str = "all") -> List:
"""Execute sql command."""
# TODO need to be implemented
print("SQL:" + command)
logger.info("SQL:" + command)
if not command or len(command) < 0:
return []
_, ttype, sql_type, table_name = self.__sql_parse(command)
@@ -199,10 +207,12 @@ class ClickhouseConnect(RDBMSDatabase):
else:
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
logger.info(f"write result query:{select_sql}")
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
logger.info(
"DDL execution determines whether to enable through configuration "
)
cursor = self.client.command(command)
@@ -212,7 +222,7 @@ class ClickhouseConnect(RDBMSDatabase):
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
logger.info("DDL Result:" + str(result))
if not result:
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
@@ -225,13 +235,16 @@ class ClickhouseConnect(RDBMSDatabase):
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_current_db_name(self):
"""Get current database name."""
return self.client.database
def get_table_comments(self, db_name: str):
"""Get table comments."""
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
SELECT table, comment FROM system.tables WHERE database = '{db_name}'
""".format(
db_name
)
@@ -241,6 +254,7 @@ class ClickhouseConnect(RDBMSDatabase):
def get_table_comment(self, table_name: str) -> Dict:
"""Get table comment.
Args:
table_name (str): table name
Returns:
@@ -249,7 +263,9 @@ class ClickhouseConnect(RDBMSDatabase):
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
SELECT table, comment FROM system.tables WHERE
database = '{self.client.database}'and table = '{table_name}'
""".format(
self.client.database
)
@@ -258,9 +274,11 @@ class ClickhouseConnect(RDBMSDatabase):
return [{"text": comment} for table_name, comment in table_comments][0]
def get_column_comments(self, db_name, table_name):
"""Get column comments."""
session = self.client
_query_sql = f"""
select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
select name column, comment from system.columns where database='{db_name}'
and table='{table_name}'
""".format(
db_name, table_name
)
@@ -270,10 +288,13 @@ class ClickhouseConnect(RDBMSDatabase):
return column_comments
def table_simple_info(self):
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
"""Get table simple info."""
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray
# instead; and quotes need to be escaped
_sql = f"""
SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
SELECT concat(TABLE_NAME, '(', arrayStringConcat(
groupArray(column_name), '-'), ')') AS schema_info
FROM INFORMATION_SCHEMA.COLUMNS
WHERE table_schema = '{self.get_current_db_name()}'
GROUP BY TABLE_NAME
@@ -282,18 +303,18 @@ class ClickhouseConnect(RDBMSDatabase):
return [row[0] for block in stream for row in block]
def _write(self, write_sql: str):
"""write data
"""Execute write sql.
Args:
write_sql (str): sql string
"""
# TODO need to be implemented
print(f"Write[{write_sql}]")
logger.info(f"Write[{write_sql}]")
result = self.client.command(write_sql)
print(f"SQL[{write_sql}], result:{result.written_rows}")
logger.info(f"SQL[{write_sql}], result:{result.written_rows}")
def _query(self, query: str, fetch: str = "all"):
"""Query data from clickhouse
"""Query data from clickhouse.
Args:
query (str): sql string
@@ -306,7 +327,7 @@ class ClickhouseConnect(RDBMSDatabase):
_type_: List<Result>
"""
# TODO need to be implemented
print(f"Query[{query}]")
logger.info(f"Query[{query}]")
if not query:
return []
@@ -334,11 +355,13 @@ class ClickhouseConnect(RDBMSDatabase):
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
logger.info(
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
)
return parsed, ttype, sql_type, table_name
def _sync_tables_from_db(self) -> Iterable[str]:
"""Read table information from database"""
"""Read table information from database."""
# TODO Use a background thread to refresh periodically
# SQL will raise error with schema