# Mirror of https://github.com/csunny/DB-GPT.git
# (synced 2025-07-28 06:17:14 +00:00; 388 lines, 12 KiB, Python)
"""Clickhouse connector."""
|
|
import logging
|
|
import re
|
|
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
|
import sqlparse
|
|
from sqlalchemy import MetaData, text
|
|
|
|
from dbgpt.storage.schema import DBType
|
|
|
|
from .base import RDBMSConnector
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ClickhouseConnector(RDBMSConnector):
    """Clickhouse connector.

    Talks to ClickHouse through a client object (a ``clickhouse_connect``
    client when built via ``from_uri_db``). ``__init__`` does not call
    ``super().__init__``, so attributes the base class would normally set
    (e.g. an engine/inspector) are not available here.
    """

    # db type
    db_type: str = "clickhouse"
    # db driver
    driver: str = "clickhouse"
    # db dialect
    db_dialect: str = "clickhouse"

    client: Any = None

    # Leading keywords of non-DML statements that still produce a result set;
    # ``run`` sends these through ``_query`` rather than ``client.command``.
    _ROW_RETURNING_KEYWORDS = frozenset(
        {"SHOW", "DESCRIBE", "DESC", "EXPLAIN", "EXISTS"}
    )

    def __init__(self, client, **kwargs):
        """Create a new ClickhouseConnector from client.

        Args:
            client: The ClickHouse client used for every query, typically the
                ``clickhouse_connect`` client built by ``from_uri_db``.
            **kwargs: Accepted for interface compatibility; currently unused.
        """
        # NOTE(review): RDBMSConnector.__init__ is intentionally not called;
        # presumably it expects an SQLAlchemy engine this connector lacks.
        self.client = client

        self._all_tables = set()
        self.view_support = False
        self._usable_tables = set()
        self._include_tables = set()
        self._ignore_tables = set()
        self._custom_table_info = set()
        self._indexes_in_table_info = set()
        self._sample_rows_in_table_info = set()

        self._metadata = MetaData()

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> "ClickhouseConnector":
        """Create a new ClickhouseConnector from host, port, user, pwd, db_name.

        Args:
            host: ClickHouse server host.
            port: ClickHouse server port.
            user: Login user.
            pwd: Login password.
            db_name: Default database of the client session.
            engine_args: Accepted for interface compatibility; unused here.
            **kwargs: Forwarded to the constructor.

        Returns:
            A connector wrapping a newly created ``clickhouse_connect`` client.
        """
        # Lazy import: the driver is only imported when this constructor runs.
        import clickhouse_connect
        from clickhouse_connect.driver import httputil

        big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
        client = clickhouse_connect.get_client(
            host=host,
            user=user,
            password=pwd,
            port=port,
            connect_timeout=15,
            database=db_name,
            settings={"distributed_ddl_task_timeout": 300},
            pool_mgr=big_pool_mgr,
        )
        # The client is stored on the instance by __init__. It is deliberately
        # NOT also assigned to ``cls.client``: that would mutate shared class
        # state and keep the most recent client alive on the class.
        return cls(client, **kwargs)

    @staticmethod
    def _quote_literal(value: Any) -> str:
        """Render ``value`` as a single-quoted ClickHouse string literal.

        Backslashes and single quotes are backslash-escaped so names such as
        ``o'brien`` cannot break out of the literal.
        """
        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
        return f"'{escaped}'"

    def get_table_names(self):
        """Get all table names (as listed by ``SHOW TABLES``)."""
        with self.client.query_row_block_stream("SHOW TABLES") as stream:
            return [row[0] for block in stream for row in block]

    def get_indexes(self, table_name: str) -> List[Dict]:
        """Get table indexes about specified table.

        Only the primary key recorded in ``system.tables`` is reported; a table
        without a primary key (or an unknown table) yields an empty list.

        Args:
            table_name (str): table name

        Returns:
            indexes: List[Dict], eg:[{'name': 'primary_key', 'column_names': ['id']}]
        """
        query_sql = (
            "SELECT name AS table, primary_key FROM system.tables "
            f"WHERE database = {self._quote_literal(self.client.database)} "
            f"AND table = {self._quote_literal(table_name)}"
        )
        with self.client.query_row_block_stream(query_sql) as stream:
            rows = [row for block in stream for row in block]

        indexes = []
        for _, primary_key in rows:
            # ``primary_key`` is a comma-separated expression list (presumably
            # like "id, ts"); strip padding and drop empty entries.
            column_names = [
                col.strip() for col in (primary_key or "").split(",") if col.strip()
            ]
            if column_names:
                indexes.append({"name": "primary_key", "column_names": column_names})
        return indexes

    @property
    def table_info(self) -> str:
        """Get table info."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Not implemented for ClickHouse yet: always returns an empty string.
        """
        # TODO:
        return ""

    def get_show_create_table(self, table_name):
        """Get table show create table about specified table.

        ``ENGINE = MergeTree``, ``DEFAULT CHARSET = ...`` and the whole
        ``SETTINGS k = v[, k = v ...]`` clause are stripped from the DDL.
        """
        # ``client.command`` takes a SQL string; a sqlalchemy ``text()``
        # TextClause is not one, so the statement is passed as plain text.
        ddl = str(self.client.command(f"SHOW CREATE TABLE {table_name}"))

        ddl = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ddl, flags=re.IGNORECASE)
        ddl = re.sub(
            r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ddl, flags=re.IGNORECASE
        )
        # Remove every "name = value" pair of the SETTINGS list; matching only
        # the first setting name would leave a dangling "= value" in the DDL.
        ddl = re.sub(
            r"\s*SETTINGS\s+\w+\s*=\s*(?:'[^']*'|[^,\s]+)"
            r"(?:\s*,\s*\w+\s*=\s*(?:'[^']*'|[^,\s]+))*\s*",
            " ",
            ddl,
            flags=re.IGNORECASE,
        )
        return ddl

    def get_columns(self, table_name: str) -> List[Dict]:
        """Get columns of ``table_name`` in the current database.

        Args:
            table_name (str): str

        Returns:
            List[Dict], which contains name: str, type: str,
            default_expression: str, is_in_primary_key: bool, comment: str
            eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
            'is_in_primary_key': True, 'comment': 'id'}, ...]
            An unknown table yields an empty list.
        """
        fields = self.get_fields(table_name, self.get_current_db_name())
        # ``get_fields`` returns row blocks; walk every block, not just the
        # first one.
        return [
            {
                "name": name,
                "comment": comment,
                "type": column_type,
                "default_expression": default_expression,
                "is_in_primary_key": bool(is_in_primary_key),
            }
            for block in fields
            for (
                name,
                column_type,
                default_expression,
                is_in_primary_key,
                comment,
            ) in block
        ]

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use (currently empty)."""
        return ""

    def get_fields(self, table_name, db_name=None) -> List[Tuple]:
        """Get column fields about specified table.

        Args:
            table_name: Table whose columns are read from ``system.columns``.
            db_name: Optional database filter. When falsy, columns of
                same-named tables in every database are returned.

        Returns:
            The row blocks of the result stream (a list of blocks, each a
            sequence of ``(name, type, default_expression, is_in_primary_key,
            comment)`` rows) -- not a flat list of rows.
        """
        query_sql = (
            "SELECT name, type, default_expression, is_in_primary_key, comment "
            f"FROM system.columns WHERE table = {self._quote_literal(table_name)}"
        )
        if db_name:
            query_sql += f" AND database = {self._quote_literal(db_name)}"
        with self.client.query_row_block_stream(query_sql) as stream:
            return list(stream)

    def get_users(self):
        """Get user info (not supported; always empty)."""
        return []

    def get_grants(self):
        """Get grants (not supported; always empty)."""
        return []

    def get_collation(self):
        """Get collation."""
        return "UTF-8"

    def get_charset(self):
        """Get character_set."""
        return "UTF-8"

    def get_database_names(self):
        """Get database names, excluding system/default/information-schema ones."""
        excluded = ("INFORMATION_SCHEMA", "system", "default", "information_schema")
        # ``client.command`` returns a plain value, not a context manager, so
        # the result is streamed with ``query_row_block_stream`` instead.
        with self.client.query_row_block_stream("SHOW DATABASES") as stream:
            return [
                row[0] for block in stream for row in block if row[0] not in excluded
            ]

    def run(self, command: str, fetch: str = "all") -> List:
        """Execute sql command.

        Routing:
          * ``SELECT`` statements (sqlparse also types ``WITH ... SELECT`` as
            SELECT) and the result-producing statements SHOW, DESCRIBE/DESC,
            EXPLAIN and EXISTS go through ``_query`` and return
            ``[column_names, *rows]``.
          * Other DML is executed via ``_write``, followed by the query
            produced by ``convert_sql_write_to_select``.
          * Anything else is executed via ``client.command``; the result is
            the table's column listing when a table name was parsed, else
            ``[]``.

        Args:
            command: SQL text; empty or whitespace-only input returns ``[]``.
            fetch: "all" or "one", forwarded to ``_query``.
        """
        if not command or not command.strip():
            return []
        logger.info("SQL:" + command)

        parsed, ttype, sql_type, table_name = self.__sql_parse(command)
        if sql_type == "SELECT":
            return self._query(command, fetch)
        if ttype == sqlparse.tokens.DML:
            self._write(command)
            select_sql = self.convert_sql_write_to_select(command)
            logger.info(f"write result query:{select_sql}")
            return self._query(select_sql)

        first_token = parsed.token_first(skip_ws=True, skip_cm=True)
        keyword = first_token.normalized.upper() if first_token is not None else ""
        if keyword in self._ROW_RETURNING_KEYWORDS:
            return self._query(command, fetch)

        logger.info(
            "DDL execution determines whether to enable through configuration "
        )
        # ``command`` does not return a row cursor, so the statement's own
        # output is only logged; callers get the table's columns instead.
        result = self.client.command(command)
        logger.info("DDL Result:" + str(result))
        if not table_name:
            return []
        return self.get_simple_fields(table_name)

    def get_simple_fields(self, table_name):
        """Get column fields about specified table via ``SHOW COLUMNS``."""
        return self._query(f"SHOW COLUMNS FROM {table_name}")

    def get_current_db_name(self):
        """Get current database name."""
        return self.client.database

    def get_table_comments(self, db_name: str):
        """Get ``(table, comment)`` rows for every table in ``db_name``."""
        query_sql = (
            "SELECT table, comment FROM system.tables "
            f"WHERE database = {self._quote_literal(db_name)}"
        )
        with self.client.query_row_block_stream(query_sql) as stream:
            return [row for block in stream for row in block]

    def get_table_comment(self, table_name: str) -> Dict:
        """Get table comment.

        Args:
            table_name (str): table name

        Returns:
            comment: Dict, which contains text: Optional[str], eg:{"text": "comment"}
            ``{"text": None}`` when the table is not found.
        """
        query_sql = (
            "SELECT table, comment FROM system.tables "
            f"WHERE database = {self._quote_literal(self.client.database)} "
            f"AND table = {self._quote_literal(table_name)}"
        )
        with self.client.query_row_block_stream(query_sql) as stream:
            rows = [row for block in stream for row in block]
        if not rows:
            return {"text": None}
        _, comment = rows[0]
        return {"text": comment}

    def get_column_comments(self, db_name, table_name):
        """Get ``(column, comment)`` rows for ``db_name.table_name``."""
        query_sql = (
            "SELECT name AS column, comment FROM system.columns "
            f"WHERE database = {self._quote_literal(db_name)} "
            f"AND table = {self._quote_literal(table_name)}"
        )
        with self.client.query_row_block_stream(query_sql) as stream:
            return [row for block in stream for row in block]

    def table_simple_info(self):
        """Get one ``table(col1-col2-...)`` string per table of the current db."""
        # group_concat() not supported in clickhouse, use arrayStringConcat+groupArray
        # instead; and quotes need to be escaped
        current_db = self._quote_literal(self.get_current_db_name())
        query_sql = f"""
            SELECT concat(TABLE_NAME, '(', arrayStringConcat(
                groupArray(column_name), '-'), ')') AS schema_info
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE table_schema = {current_db}
            GROUP BY TABLE_NAME
        """
        with self.client.query_row_block_stream(query_sql) as stream:
            return [row[0] for block in stream for row in block]

    def _write(self, write_sql: str):
        """Execute write sql.

        Args:
            write_sql (str): sql string
        """
        logger.info(f"Write[{write_sql}]")
        result = self.client.command(write_sql)
        # Only summary-style results carry ``written_rows``; other return
        # values (plain strings/ints) do not, so avoid an AttributeError.
        written_rows = getattr(result, "written_rows", None)
        logger.info(f"SQL[{write_sql}], result:{written_rows}")

    def _query(self, query: str, fetch: str = "all"):
        """Query data from clickhouse.

        Args:
            query (str): sql string
            fetch (str, optional): "one" or "all". Defaults to "all".

        Raises:
            ValueError: if ``fetch`` is neither "one" nor "all" (checked
                before the query is sent).

        Returns:
            List whose first element is the tuple of column names, followed
            by all rows ("all") or at most the first row ("one"). An empty
            ``query`` returns ``[]``.
        """
        logger.info(f"Query[{query}]")
        if not query:
            return []
        if fetch not in ("all", "one"):
            raise ValueError("Fetch parameter must be either 'one' or 'all'")

        cursor = self.client.query(query)
        # Copy so the header insert does not mutate the client's result.
        rows = list(cursor.result_rows)
        if fetch == "one":
            rows = rows[:1]
        rows.insert(0, cursor.column_names)
        return rows

    def __sql_parse(self, sql):
        """Parse ``sql`` and classify it.

        Returns:
            ``(parsed, ttype, sql_type, table_name)``: the sqlparse statement,
            the token type of its first non-whitespace, non-comment token (or
            None), ``parsed.get_type()`` (e.g. "SELECT", "CREATE", "UNKNOWN"),
            and the parsed table name (may be None).
        """
        sql = sql.strip()
        parsed = sqlparse.parse(sql)[0]
        sql_type = parsed.get_type()
        if sql_type == "CREATE":
            table_name = self._extract_table_name_from_ddl(parsed)
        else:
            table_name = parsed.get_name()

        # Skip leading comments so "/* hint */ INSERT ..." is still seen as DML.
        first_token = parsed.token_first(skip_ws=True, skip_cm=True)
        ttype = first_token.ttype if first_token is not None else None
        logger.info(
            f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
        )
        return parsed, ttype, sql_type, table_name

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Read table information from database.

        ``__init__`` never sets ``self._engine`` / ``self._inspector``, so the
        table list is read with ``SHOW TABLES`` through the client instead.
        In ClickHouse, ``SHOW TABLES`` also lists views.
        """
        # TODO Use a background thread to refresh periodically
        self._all_tables = set(self.get_table_names())
        return self._all_tables
|