mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
fix: clickhouse connect error fix (#958)
Co-authored-by: Magic <magic@B-4TMH9N3X-2120.local> Co-authored-by: aries_ckt <916701291@qq.com>
This commit is contained in:
@@ -1,10 +1,17 @@
|
||||
import re
|
||||
from typing import Optional, Any
|
||||
import sqlparse
|
||||
import clickhouse_connect
|
||||
from typing import List, Optional, Any, Iterable, Dict
|
||||
from sqlalchemy import text
|
||||
from urllib.parse import quote
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from clickhouse_connect.driver import httputil
|
||||
from dbgpt.storage.schema import DBType
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
)
|
||||
|
||||
|
||||
class ClickhouseConnect(RDBMSDatabase):
|
||||
@@ -20,6 +27,24 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
"""db dialect"""
|
||||
db_dialect: str = "clickhouse"
|
||||
|
||||
client: Any = None
|
||||
|
||||
def __init__(self, client, **kwargs):
    """Wrap an already-connected ``clickhouse_connect`` driver client.

    Args:
        client: connected ``clickhouse_connect`` client; the only handle
            this connector uses to talk to ClickHouse.
        **kwargs: accepted for signature compatibility; currently unused.
    """
    self.client = client

    # Bookkeeping containers mirroring the RDBMSDatabase attributes.
    # The original assigned ``_usable_tables`` three times; once suffices.
    self._all_tables = set()
    self.view_support = False
    self._usable_tables = set()
    self._include_tables = set()
    self._ignore_tables = set()
    # NOTE(review): the base class may expect a dict for custom table info —
    # kept as set() to preserve the original behavior; confirm upstream.
    self._custom_table_info = set()
    self._indexes_in_table_info = set()
    self._sample_rows_in_table_info = set()

    self._metadata = MetaData()
|
||||
|
||||
@classmethod
def from_uri_db(
    cls,
    host: str,
    port: int,
    user: str,
    pwd: str,
    db_name: str,
    engine_args: Optional[dict] = None,
    **kwargs: Any,
) -> RDBMSDatabase:
    """Build a ClickhouseConnect from connection parameters.

    Removed the fused pre-refactor residue: an orphaned SQLAlchemy
    ``db_url`` f-string with an unclosed parenthesis and a stale
    ``get_indexes`` stub that made this method syntactically invalid.

    Args:
        host (str): server address.
        port (int): HTTP(S) port of the ClickHouse server.
        user (str): user name.
        pwd (str): password.
        db_name (str): database to connect to.
        engine_args (Optional[dict]): accepted for interface compatibility
            with the SQLAlchemy-based connectors; unused here.
        **kwargs: forwarded to ``cls(...)``.

    Returns:
        RDBMSDatabase: a connected ClickhouseConnect instance.
    """
    # A larger pool lets concurrent queries share the HTTP transport.
    big_pool_mgr = httputil.get_pool_manager(maxsize=16, num_pools=12)
    client = clickhouse_connect.get_client(
        host=host,
        user=user,
        password=pwd,
        port=port,
        connect_timeout=15,
        database=db_name,
        settings={"distributed_ddl_task_timeout": 300},
        pool_mgr=big_pool_mgr,
    )
    # NOTE(review): storing the client on the class mirrors the original;
    # it means the last connection wins across instances — confirm intent.
    cls.client = client
    return cls(client, **kwargs)
|
||||
|
||||
@property
def dialect(self) -> str:
    """Return string representation of dialect to use.

    The original body was a bare ``pass`` and returned None despite the
    declared ``-> str``; now reports the class-level ``db_dialect``
    ("clickhouse").
    """
    return self.db_dialect
|
||||
|
||||
def get_table_names(self):
    """Return the name of every table in the current database."""
    names = []
    with self.client.query_row_block_stream("SHOW TABLES") as stream:
        for block in stream:
            for row in block:
                names.append(row[0])
    return names
|
||||
|
||||
def get_indexes(self, table_name: str) -> List[Dict]:
    """Get table indexes about specified table.

    Args:
        table_name (str): table name
    Returns:
        indexes: List[Dict], eg:[{'name': 'idx_key', 'column_names': ['id']}]
    """
    session = self.client
    # Fixed SQL: the original had a stray comma before FROM
    # ("SELECT name AS table, primary_key, from system.tables"), which is
    # a ClickHouse syntax error.
    _query_sql = f"""
        SELECT name AS table, primary_key FROM system.tables where database ='{self.client.database}' and table = '{table_name}'
    """
    with session.query_row_block_stream(_query_sql) as stream:
        indexes = [block for block in stream]
    if not indexes or not indexes[0]:
        # Unknown table: report no indexes instead of raising IndexError.
        return []
    return [
        {"name": "primary_key", "column_names": column_names.split(",")}
        for table, column_names in indexes[0]
    ]
|
||||
|
||||
@property
def table_info(self) -> str:
    """Table information for the database, via :meth:`get_table_info`."""
    info = self.get_table_info()
    return info
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
    """Describe the given tables (all tables when *table_names* is None).

    Intended to follow the prompt format of Rajkumar et al., 2022
    (https://arxiv.org/abs/2204.00498): when `sample_rows_in_table_info`
    is set, that many sample rows are appended to each table description,
    which the paper shows can improve downstream performance.
    """
    # TODO: not implemented for ClickHouse yet; currently returns None.
    pass
|
||||
|
||||
def get_show_create_table(self, table_name):
    """Get table show create table about specified table.

    Removed the fused pre-refactor SQLAlchemy session/cursor path; the
    clickhouse_connect client answers SHOW CREATE TABLE directly.
    """
    # NOTE(review): wrapping the statement in sqlalchemy ``text()`` before
    # handing it to ``client.command`` matches the original code, but the
    # driver expects a plain string — confirm this works as intended.
    result = self.client.command(text(f"SHOW CREATE TABLE {table_name}"))

    ans = result
    # Strip engine/charset/settings clauses so the DDL reads as pure schema.
    ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
    ans = re.sub(
        r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
    )
    ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
    return ans
|
||||
|
||||
def get_columns(self, table_name: str) -> List[Dict]:
    """Get columns.

    Args:
        table_name (str): str
    Returns:
        columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
        eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
    """
    rows = self.get_fields(table_name)
    columns = []
    # Only the first block of the streamed result carries the field rows.
    for name, column_type, _default, _is_pk, comment in rows[0]:
        columns.append({"name": name, "comment": comment, "type": column_type})
    return columns
|
||||
|
||||
def get_fields(self, table_name):
    """Get column fields about specified table.

    Returns the streamed row blocks from ``system.columns``; the first
    block holds rows of
    (name, type, default_expression, is_in_primary_key, comment).

    Removed the fused legacy SQLAlchemy cursor code and a no-op
    ``.format(table_name)`` chained onto an f-string (the interpolation
    had already happened).
    """
    session = self.client
    _query_sql = f"""
        SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
    """
    with session.query_row_block_stream(_query_sql) as stream:
        fields = [block for block in stream]
    return fields
|
||||
|
||||
def get_users(self):
    """Return an empty list (user enumeration is not implemented here)."""
    return list()
|
||||
@@ -80,31 +173,187 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
    """List user databases, excluding system/internal schemas.

    Removed the fused pre-refactor ``return []`` that made this whole
    body unreachable dead code.
    """
    session = self.client
    # NOTE(review): ``client.command`` is used as a context manager that
    # yields row blocks, mirroring the original; confirm the driver
    # supports this usage (query_row_block_stream is the usual API).
    with session.command("SHOW DATABASES") as stream:
        databases = [
            row[0]
            for block in stream
            for row in block
            if row[0]
            not in ("INFORMATION_SCHEMA", "system", "default", "information_schema")
        ]
    return databases
|
||||
|
||||
def get_database_names(self):
    """Alias for :meth:`get_database_list`.

    Removed the fused pre-refactor ``return []`` that shadowed the real
    delegation and always returned an empty list.
    """
    return self.get_database_list()
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
    """Execute *command* against ClickHouse.

    DML SELECTs return rows; other DML is written and then verified with a
    derived SELECT; anything else is treated as DDL. The orphaned
    pre-refactor ``get_table_comments`` fragment that was fused in front
    of this method has been removed.

    Args:
        command (str): SQL to execute.
        fetch (str): "one" or "all", forwarded to :meth:`_query`.

    Returns:
        List: header row followed by data rows, or simple field info.
    """
    # TODO need to be implemented
    print("SQL:" + command)
    if not command:
        # the original also tested ``len(command) < 0``, which is never true.
        return []
    _, ttype, sql_type, table_name = self.__sql_parse(command)
    if ttype == sqlparse.tokens.DML:
        if sql_type == "SELECT":
            return self._query(command, fetch)
        else:
            self._write(command)
            select_sql = self.convert_sql_write_to_select(command)
            print(f"write result query:{select_sql}")
            return self._query(select_sql)
    else:
        print(f"DDL execution determines whether to enable through configuration ")

        # NOTE(review): ``client.command`` returns a summary for DDL; the
        # attributes below assume a query-result-like object — confirm.
        cursor = self.client.command(command)

        if cursor.written_rows:
            result = cursor.result_rows
            # Fixed: column names come from the cursor, not from the row
            # list (a plain list has no ``column_names`` attribute).
            field_names = cursor.column_names

            result = list(result)
            result.insert(0, field_names)
            print("DDL Result:" + str(result))
            if not result:
                # return self._query(f"SHOW COLUMNS FROM {table_name}")
                return self.get_simple_fields(table_name)
            return result
        else:
            return self.get_simple_fields(table_name)
|
||||
|
||||
def get_simple_fields(self, table_name):
    """Get column fields about specified table."""
    show_columns_sql = f"SHOW COLUMNS FROM {table_name}"
    return self._query(show_columns_sql)
|
||||
|
||||
def get_current_db_name(self):
    """Return the name of the database the underlying client is bound to."""
    return self.client.database
|
||||
|
||||
def get_table_comments(self, db_name: str):
    """Return (table, comment) pairs for every table in *db_name*.

    Removed the orphaned pre-refactor ``cursor.fetchall()`` fragment that
    was fused into the middle of this method, and a no-op
    ``.format(db_name)`` chained onto an f-string.
    """
    session = self.client
    _query_sql = f"""
        SELECT table, comment FROM system.tables WHERE database = '{db_name}'"""
    with session.query_row_block_stream(_query_sql) as stream:
        table_comments = [row for block in stream for row in block]
        return table_comments
|
||||
|
||||
def get_table_comment(self, table_name: str) -> Dict:
    """Get table comment.

    Args:
        table_name (str): table name
    Returns:
        comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
    """
    session = self.client
    # Fixed: added the missing space before ``and`` in the WHERE clause
    # ("...'{database}'and table...") and dropped the no-op ``.format()``
    # on an already-interpolated f-string.
    _query_sql = f"""
        SELECT table, comment FROM system.tables WHERE database = '{self.client.database}' and table = '{table_name}'"""
    with session.query_row_block_stream(_query_sql) as stream:
        table_comments = [row for block in stream for row in block]
    # NOTE(review): still raises IndexError when the table does not exist,
    # matching the original behavior.
    return [{"text": comment} for table_name, comment in table_comments][0]
|
||||
|
||||
def get_column_comments(self, db_name, table_name):
    """Return (column, comment) pairs for *table_name* in *db_name*.

    Removed the no-op ``.format(db_name, table_name)`` — the f-string had
    already interpolated both values.
    """
    session = self.client
    _query_sql = f"""
        select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
    """
    with session.query_row_block_stream(_query_sql) as stream:
        column_comments = [row for block in stream for row in block]
        return column_comments
|
||||
|
||||
def table_simple_info(self):
    """Return one ``table(name-col1-col2-...)`` summary string per table.

    group_concat() is not supported in ClickHouse, so
    arrayStringConcat + groupArray is used instead. Removed the fused
    pre-refactor SQLAlchemy ``self.session.execute`` path that preceded
    (and shadowed) this client-based implementation.
    """
    _sql = f"""
        SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
        FROM information_schema.COLUMNS
        WHERE table_schema = '{self.get_current_db_name()}'
        GROUP BY TABLE_NAME
    """
    with self.client.query_row_block_stream(_sql) as stream:
        return [row[0] for block in stream for row in block]
|
||||
|
||||
def _write(self, write_sql: str):
    """Execute a write statement through the client and log its row count.

    Args:
        write_sql (str): sql string
    """
    # TODO need to be implemented
    print(f"Write[{write_sql}]")
    summary = self.client.command(write_sql)
    print(f"SQL[{write_sql}], result:{summary.written_rows}")
|
||||
|
||||
def _query(self, query: str, fetch: str = "all"):
    """Run a read query against ClickHouse and prepend a header row.

    Args:
        query (str): sql string
        fetch (str, optional): "one" or "all". Defaults to "all".

    Raises:
        ValueError: when *fetch* is neither "one" nor "all".

    Returns:
        _type_: List<Result> — column-name header followed by data rows.
    """
    # TODO need to be implemented
    print(f"Query[{query}]")

    if not query:
        return []

    cursor = self.client.query(query)
    if fetch not in ("all", "one"):
        raise ValueError("Fetch parameter must be either 'one' or 'all'")
    result = cursor.result_rows if fetch == "all" else cursor.first_row

    field_names = cursor.column_names
    result.insert(0, field_names)
    return result
|
||||
|
||||
def __sql_parse(self, sql):
    """Parse *sql* with sqlparse and classify it.

    Returns:
        tuple: (parsed statement, first-token ttype, statement type,
        affected table name).
    """
    stripped = sql.strip()
    statement = sqlparse.parse(stripped)[0]
    stmt_type = statement.get_type()
    # CREATE statements need DDL-aware table-name extraction.
    if stmt_type == "CREATE":
        table = self._extract_table_name_from_ddl(statement)
    else:
        table = statement.get_name()

    leading = statement.token_first(skip_ws=True, skip_cm=False)
    token_type = leading.ttype
    print(f"SQL:{stripped}, ttype:{token_type}, sql_type:{stmt_type}, table:{table}")
    return statement, token_type, stmt_type, table
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
    """Read table information from database"""
    # TODO Use a background thread to refresh periodically

    # Passing a schema to the inspector raises for SQLite, so only supply
    # one for every other backend.
    if self.db_type == DBType.SQLite.value():
        _schema = None
    else:
        _schema = self._engine.url.database

    # Include views alongside tables when view support is enabled.
    names = self._inspector.get_table_names(schema=_schema)
    if self.view_support:
        names = names + self._inspector.get_view_names(schema=_schema)
    self._all_tables = set(names)
    return self._all_tables
|
||||
|
Reference in New Issue
Block a user