# File: DB-GPT/dbgpt/datasource/rdbms/conn_clickhouse.py
# (source snapshot: 360 lines, 12 KiB, Python)

import re
import sqlparse
from typing import List, Optional, Any, Iterable, Dict
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from dbgpt.storage.schema import DBType
from sqlalchemy import (
MetaData,
)
class ClickhouseConnect(RDBMSDatabase):
    """Connect to a ClickHouse database and fetch its metadata.

    This connector talks to ClickHouse through a ``clickhouse_connect`` client
    (``self.client``) rather than a SQLAlchemy engine. ``__init__`` does not call
    the base-class initializer, so no SQLAlchemy engine, session or inspector is
    created. Every lookup below goes through the client.

    Usage:
        conn = ClickhouseConnect.from_uri_db(host, port, user, pwd, db_name)
        tables = conn.get_table_names()
    """

    # db type
    db_type: str = "clickhouse"
    # db driver
    driver: str = "clickhouse"
    # db dialect
    db_dialect: str = "clickhouse"
    # clickhouse_connect client; assigned per instance in __init__
    client: Any = None

    # ``system.tables.engine`` values that denote views rather than tables.
    _VIEW_ENGINES = ("View", "MaterializedView", "LiveView", "WindowView")

    def __init__(self, client, **kwargs):
        """Wrap an existing clickhouse_connect client.

        Args:
            client: the ``clickhouse_connect`` client used for every query.
            **kwargs: accepted for interface compatibility; currently unused.
        """
        self.client = client
        self._all_tables = set()
        self.view_support = False
        self._usable_tables = set()
        self._include_tables = set()
        self._ignore_tables = set()
        self._custom_table_info = set()
        self._indexes_in_table_info = set()
        self._sample_rows_in_table_info = set()
        self._metadata = MetaData()

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> RDBMSDatabase:
        """Create a connector from connection parameters.

        Args:
            host: ClickHouse server host.
            port: port the client connects to.
            user: user name.
            pwd: password.
            db_name: database selected for the client.
            engine_args: accepted for interface compatibility with other
                connectors; unused here.
            **kwargs: forwarded to ``__init__``.

        Returns:
            A new ``ClickhouseConnect`` wrapping a freshly created client.
        """
        # Lazy import: clickhouse_connect is only needed when this connector is used.
        import clickhouse_connect
        from clickhouse_connect.driver import httputil

        # Dedicated HTTP connection pool for this client.
        big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
        client = clickhouse_connect.get_client(
            host=host,
            user=user,
            password=pwd,
            port=port,
            connect_timeout=15,
            database=db_name,
            settings={"distributed_ddl_task_timeout": 300},
            pool_mgr=big_pool_mgr,
        )
        # The client lives on the instance only (set in __init__). It is deliberately
        # not assigned to ``cls``, so one connector's client never leaks onto the class.
        return cls(client, **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self.db_dialect

    def _query_rows(
        self, sql: str, parameters: Optional[Dict[str, Any]] = None
    ) -> List:
        """Run ``sql`` and return the rows of every streamed block as one flat list.

        ``parameters`` are bound by clickhouse_connect; the SQL refers to them with
        ClickHouse ``{name:Type}`` placeholders, so values are never spliced into SQL.
        """
        with self.client.query_row_block_stream(sql, parameters=parameters) as stream:
            return [row for block in stream for row in block]

    def get_table_names(self):
        """Get all table names of the current database (``SHOW TABLES``)."""
        return [row[0] for row in self._query_rows("SHOW TABLES")]

    def get_indexes(self, table_name: str) -> List[Dict]:
        """Get the primary-key "index" of the specified table.

        Args:
            table_name (str): table name
        Returns:
            indexes: List[Dict], eg: [{'name': 'primary_key', 'column_names': ['id']}].
            Empty when the table is not found or has no primary key.
        """
        rows = self._query_rows(
            "SELECT primary_key FROM system.tables "
            "WHERE database = {db:String} AND name = {table:String}",
            {"db": self.client.database, "table": table_name},
        )
        indexes = []
        for (primary_key,) in rows:
            # primary_key is a comma-separated expression list; drop the padding spaces.
            column_names = [col.strip() for col in primary_key.split(",") if col.strip()]
            if column_names:
                indexes.append({"name": "primary_key", "column_names": column_names})
        return indexes

    @property
    def table_info(self) -> str:
        """Information about all tables; see ``get_table_info``."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        Currently this is the ``SHOW CREATE TABLE`` output of each table, as cleaned
        by ``get_show_create_table``, separated by blank lines. Sample rows
        (``sample_rows_in_table_info``) are not appended yet.

        Args:
            table_names: tables to describe; every table of the current database
                when None.
        """
        # TODO: append sample rows per table as described in the paper above.
        if table_names is None:
            table_names = self.get_table_names()
        return "\n\n".join(self.get_show_create_table(name) for name in table_names)

    def get_show_create_table(self, table_name):
        """Get the ``SHOW CREATE TABLE`` statement of ``table_name``, lightly cleaned.

        The ``ENGINE = MergeTree``, ``DEFAULT CHARSET = ...`` and
        ``SETTINGS name = value[, ...]`` clauses are removed from the statement.
        """
        # ``client.command`` takes a plain SQL string (not a SQLAlchemy clause).
        ans = self.client.command(f"SHOW CREATE TABLE {table_name}")
        ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
        ans = re.sub(
            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
        )
        # Remove the whole SETTINGS clause, values included; matching only the
        # first name would leave a dangling "= <value>" behind.
        ans = re.sub(
            r"\s*SETTINGS\s+\w+\s*=\s*[^\s,]+(?:\s*,\s*\w+\s*=\s*[^\s,]+)*",
            " ",
            ans,
            flags=re.IGNORECASE,
        )
        return ans

    def get_columns(self, table_name: str) -> List[Dict]:
        """Get columns.

        Args:
            table_name (str): table name
        Returns:
            columns: List[Dict] with keys name: str, comment: str, type: str,
            eg: [{'name': 'id', 'comment': 'id', 'type': 'UInt64'}, ...]
        """
        fields = self.get_fields(table_name)
        # get_fields returns streamed blocks; read all of them, not just the first.
        return [
            {"name": name, "comment": comment, "type": column_type}
            for block in fields
            for name, column_type, _, _, comment in block
        ]

    def get_fields(self, table_name):
        """Get column fields about specified table in the current database.

        Returns:
            A list of result blocks as streamed by ``query_row_block_stream``.
            Each block holds rows
            ``(name, type, default_expression, is_in_primary_key, comment)``.
        """
        sql = (
            "SELECT name, type, default_expression, is_in_primary_key, comment "
            "FROM system.columns "
            "WHERE database = {db:String} AND table = {table:String}"
        )
        params = {"db": self.client.database, "table": table_name}
        with self.client.query_row_block_stream(sql, parameters=params) as stream:
            return list(stream)

    def get_users(self):
        """Users are not listed for ClickHouse; always an empty list."""
        return []

    def get_grants(self):
        """Grants are not listed for ClickHouse; always an empty list."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get charset."""
        return "UTF-8"

    def get_database_list(self):
        """Get database names, excluding the system/default/information_schema ones."""
        excluded = ("INFORMATION_SCHEMA", "system", "default", "information_schema")
        # ``client.command`` is not a context manager; stream the rows instead.
        return [
            row[0]
            for row in self._query_rows("SHOW DATABASES")
            if row[0] not in excluded
        ]

    def get_database_names(self):
        """Alias of ``get_database_list``."""
        return self.get_database_list()

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute ``command`` and return a result list.

        * ``SELECT``: returns ``[column_names, *rows]`` from ``_query`` (``fetch``
          is "all" or "one").
        * Other DML: executes it, converts it with ``convert_sql_write_to_select``
          and returns that query's result.
        * Anything else (DDL etc.): executes it with ``client.command``. Returns the
          columns of the parsed table via ``get_simple_fields``, or ``[]`` when no
          table name could be parsed.
        """
        if not command or not command.strip():
            return []
        print("SQL:" + command)
        _, ttype, sql_type, table_name = self.__sql_parse(command)
        if ttype == sqlparse.tokens.DML:
            if sql_type == "SELECT":
                return self._query(command, fetch)
            self._write(command)
            select_sql = self.convert_sql_write_to_select(command)
            print(f"write result query:{select_sql}")
            return self._query(select_sql)
        print(f"DDL execution determines whether to enable through configuration ")
        # ``command`` returns a summary or the statement's raw output, with no column
        # names; report the affected table's columns instead.
        self.client.command(command)
        if not table_name:
            return []
        return self.get_simple_fields(table_name)

    def get_simple_fields(self, table_name):
        """Get column fields about specified table."""
        return self._query(f"SHOW COLUMNS FROM {table_name}")

    def get_current_db_name(self):
        """Name of the database the client is bound to."""
        return self.client.database

    def get_table_comments(self, db_name: str):
        """Get ``(table_name, comment)`` rows for every table in ``db_name``."""
        return self._query_rows(
            "SELECT name, comment FROM system.tables WHERE database = {db:String}",
            {"db": db_name},
        )

    def get_table_comment(self, table_name: str) -> Dict:
        """Get table comment.

        Args:
            table_name (str): table name
        Returns:
            comment: Dict with key text: Optional[str], eg: {"text": "comment"};
            {"text": None} when the table is not found in the current database.
        """
        rows = self._query_rows(
            "SELECT comment FROM system.tables "
            "WHERE database = {db:String} AND name = {table:String}",
            {"db": self.client.database, "table": table_name},
        )
        return {"text": rows[0][0] if rows else None}

    def get_column_comments(self, db_name, table_name):
        """Get ``(column_name, comment)`` rows for ``db_name``.``table_name``."""
        return self._query_rows(
            "SELECT name, comment FROM system.columns "
            "WHERE database = {db:String} AND table = {table:String}",
            {"db": db_name, "table": table_name},
        )

    def table_simple_info(self):
        """Summarise every table of the current database as ``table(col1-col2-...)``."""
        # ClickHouse has no group_concat(); arrayStringConcat(groupArray(...)) replaces it.
        sql = """
            SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
            FROM information_schema.COLUMNS
            WHERE table_schema = {db:String}
            GROUP BY TABLE_NAME
        """
        rows = self._query_rows(sql, {"db": self.get_current_db_name()})
        return [row[0] for row in rows]

    def _write(self, write_sql: str):
        """Execute a write statement.

        Args:
            write_sql (str): sql string
        """
        print(f"Write[{write_sql}]")
        result = self.client.command(write_sql)
        # ``command`` may return the statement's output instead of a summary object,
        # so ``written_rows`` is read defensively.
        print(f"SQL[{write_sql}], result:{getattr(result, 'written_rows', result)}")

    def _query(self, query: str, fetch: str = "all"):
        """Query data from clickhouse.

        Args:
            query (str): sql string
            fetch (str, optional): "one" or "all". Defaults to "all".
        Raises:
            ValueError: if ``fetch`` is neither "one" nor "all".
        Returns:
            ``[column_names, *rows]``, with at most one row when ``fetch == "one"``;
            ``[]`` for an empty query.
        """
        print(f"Query[{query}]")
        if not query:
            return []
        if fetch not in ("all", "one"):
            raise ValueError("Fetch parameter must be either 'one' or 'all'")
        cursor = self.client.query(query)
        rows = list(cursor.result_rows)
        if fetch == "one":
            # A slice (not first_row) keeps a list shape and tolerates empty results.
            rows = rows[:1]
        return [cursor.column_names, *rows]

    def __sql_parse(self, sql):
        """Parse ``sql`` with sqlparse.

        Returns:
            (parsed statement, ttype of its first token, sqlparse statement type,
            table name or None).
        """
        sql = sql.strip()
        parsed = sqlparse.parse(sql)[0]
        sql_type = parsed.get_type()
        if sql_type == "CREATE":
            table_name = self._extract_table_name_from_ddl(parsed)
        else:
            table_name = parsed.get_name()

        first_token = parsed.token_first(skip_ws=True, skip_cm=False)
        ttype = first_token.ttype
        print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
        return parsed, ttype, sql_type, table_name

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Read table names of the current database into ``self._all_tables``.

        Views are included only when ``view_support`` is True. The lookup uses the
        ClickHouse client, because this connector has no SQLAlchemy engine/inspector.
        """
        # TODO Use a background thread to refresh periodically
        rows = self._query_rows(
            "SELECT name, engine FROM system.tables WHERE database = {db:String}",
            {"db": self.client.database},
        )
        self._all_tables = {
            name
            for name, engine in rows
            if self.view_support or engine not in self._VIEW_ENGINES
        }
        return self._all_tables