fix: clickhouse connect error fix (#958)

Co-authored-by: Magic <magic@B-4TMH9N3X-2120.local>
Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
magic.chen
2023-12-22 11:44:26 +08:00
committed by GitHub
parent d9065227bd
commit 681a8e2ed5
26 changed files with 630 additions and 91 deletions

View File

@@ -1,10 +1,17 @@
import re
from typing import Optional, Any
import sqlparse
import clickhouse_connect
from typing import List, Optional, Any, Iterable, Dict
from sqlalchemy import text
from urllib.parse import quote
from sqlalchemy.schema import CreateTable
from urllib.parse import quote_plus as urlquote
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from clickhouse_connect.driver import httputil
from dbgpt.storage.schema import DBType
from sqlalchemy import (
MetaData,
)
class ClickhouseConnect(RDBMSDatabase):
@@ -20,6 +27,24 @@ class ClickhouseConnect(RDBMSDatabase):
"""db dialect"""
db_dialect: str = "clickhouse"
client: Any = None
def __init__(self, client, **kwargs):
self.client = client
self._all_tables = set()
self.view_support = False
self._usable_tables = set()
self._include_tables = set()
self._ignore_tables = set()
self._custom_table_info = set()
self._indexes_in_table_info = set()
self._usable_tables = set()
self._usable_tables = set()
self._sample_rows_in_table_info = set()
self._metadata = MetaData()
@classmethod
def from_uri_db(
cls,
@@ -31,21 +56,75 @@ class ClickhouseConnect(RDBMSDatabase):
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
client = clickhouse_connect.get_client(
host=host,
user=user,
password=pwd,
port=port,
connect_timeout=15,
database=db_name,
settings={"distributed_ddl_task_timeout": 300},
pool_mgr=big_pool_mgr,
)
return cls.from_uri(db_url, engine_args, **kwargs)
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return ""
cls.client = client
return cls(client, **kwargs)
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
pass
def get_table_names(self):
"""Get all table names."""
session = self.client
with session.query_row_block_stream("SHOW TABLES") as stream:
tables = [row[0] for block in stream for row in block]
return tables
def get_indexes(self, table_name: str) -> List[Dict]:
"""Get table indexes about specified table.
Args:
table_name (str): table name
Returns:
indexes: List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}]
"""
session = self.client
_query_sql = f"""
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
"""
with session.query_row_block_stream(_query_sql) as stream:
indexes = [block for block in stream]
return [
{"name": "primary_key", "column_names": column_names.split(",")}
for table, column_names in indexes[0]
]
@property
def table_info(self) -> str:
return self.get_table_info()
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated in the paper.
"""
# TODO:
pass
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
ans = ans[0][0]
result = self.client.command(text(f"SHOW CREATE TABLE {table_name}"))
ans = result
ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
ans = re.sub(
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
@@ -53,18 +132,32 @@ class ClickhouseConnect(RDBMSDatabase):
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
return ans
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns.
Args:
table_name (str): str
Returns:
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
fields = self.get_fields(table_name)
return [
{"name": name, "comment": comment, "type": column_type}
for name, column_type, _, _, comment in fields[0]
]
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'".format(
table_name
)
)
session = self.client
_query_sql = f"""
SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
""".format(
table_name
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
with session.query_row_block_stream(_query_sql) as stream:
fields = [block for block in stream]
return fields
def get_users(self):
return []
@@ -80,31 +173,187 @@ class ClickhouseConnect(RDBMSDatabase):
return "UTF-8"
def get_database_list(self):
return []
session = self.client
with session.command("SHOW DATABASES") as stream:
databases = [
row[0]
for block in stream
for row in block
if row[0]
not in ("INFORMATION_SCHEMA", "system", "default", "information_schema")
]
return databases
def get_database_names(self):
return []
return self.get_database_list()
def get_table_comments(self, db_name):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
db_name
)
)
def run(self, command: str, fetch: str = "all") -> List:
# TODO need to be implemented
print("SQL:" + command)
if not command or len(command) < 0:
return []
_, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self._query(command, fetch)
else:
self._write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self._query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
cursor = self.client.command(command)
if cursor.written_rows:
result = cursor.result_rows
field_names = result.column_names
result = list(result)
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
# return self._query(f"SHOW COLUMNS FROM {table_name}")
return self.get_simple_fields(table_name)
return result
else:
return self.get_simple_fields(table_name)
def get_simple_fields(self, table_name):
"""Get column fields about specified table."""
return self._query(f"SHOW COLUMNS FROM {table_name}")
def get_current_db_name(self):
return self.client.database
def get_table_comments(self, db_name: str):
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
db_name
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
with session.query_row_block_stream(_query_sql) as stream:
table_comments = [row for block in stream for row in block]
return table_comments
def get_table_comment(self, table_name: str) -> Dict:
"""Get table comment.
Args:
table_name (str): table name
Returns:
comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
"""
session = self.client
_query_sql = f"""
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
self.client.database
)
with session.query_row_block_stream(_query_sql) as stream:
table_comments = [row for block in stream for row in block]
return [{"text": comment} for table_name, comment in table_comments][0]
def get_column_comments(self, db_name, table_name):
session = self.client
_query_sql = f"""
select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
""".format(
db_name, table_name
)
with session.query_row_block_stream(_query_sql) as stream:
column_comments = [row for block in stream for row in block]
return column_comments
def table_simple_info(self):
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
_sql = f"""
select concat(TABLE_NAME, \'(\' , arrayStringConcat(groupArray(column_name),\'-\'), \')\') as schema_info
from information_schema.COLUMNS where table_schema=\'{self.get_current_db_name()}\' group by TABLE_NAME; """
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
return results
_sql = f"""
SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
FROM information_schema.COLUMNS
WHERE table_schema = '{self.get_current_db_name()}'
GROUP BY TABLE_NAME
"""
with self.client.query_row_block_stream(_sql) as stream:
return [row[0] for block in stream for row in block]
def _write(self, write_sql: str):
"""write data
Args:
write_sql (str): sql string
"""
# TODO need to be implemented
print(f"Write[{write_sql}]")
result = self.client.command(write_sql)
print(f"SQL[{write_sql}], result:{result.written_rows}")
def _query(self, query: str, fetch: str = "all"):
"""Query data from clickhouse
Args:
query (str): sql string
fetch (str, optional): "one" or "all". Defaults to "all".
Raises:
ValueError: Error
Returns:
_type_: List<Result>
"""
# TODO need to be implemented
print(f"Query[{query}]")
if not query:
return []
cursor = self.client.query(query)
if fetch == "all":
result = cursor.result_rows
elif fetch == "one":
result = cursor.first_row
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = cursor.column_names
result.insert(0, field_names)
return result
def __sql_parse(self, sql):
sql = sql.strip()
parsed = sqlparse.parse(sql)[0]
sql_type = parsed.get_type()
if sql_type == "CREATE":
table_name = self._extract_table_name_from_ddl(parsed)
else:
table_name = parsed.get_name()
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
ttype = first_token.ttype
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
return parsed, ttype, sql_type, table_name
def _sync_tables_from_db(self) -> Iterable[str]:
"""Read table information from database"""
# TODO Use a background thread to refresh periodically
# SQL will raise error with schema
_schema = (
None if self.db_type == DBType.SQLite.value() else self._engine.url.database
)
# including view support by adding the views as well as tables to the all
# tables list if view_support is True
self._all_tables = set(
self._inspector.get_table_names(schema=_schema)
+ (
self._inspector.get_view_names(schema=_schema)
if self.view_support
else []
)
)
return self._all_tables