mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-14 13:40:54 +00:00
refactor: Refactor datasource module (#1309)
This commit is contained in:
@@ -1,18 +1,20 @@
|
||||
"""Clickhouse connector."""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import sqlparse
|
||||
from sqlalchemy import MetaData, text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.storage.schema import DBType
|
||||
|
||||
from .base import RDBMSConnector
|
||||
|
||||
class ClickhouseConnect(RDBMSDatabase):
|
||||
"""Connect Clickhouse Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClickhouseConnector(RDBMSConnector):
|
||||
"""Clickhouse connector."""
|
||||
|
||||
"""db type"""
|
||||
db_type: str = "clickhouse"
|
||||
@@ -24,6 +26,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
client: Any = None
|
||||
|
||||
def __init__(self, client, **kwargs):
|
||||
"""Create a new ClickhouseConnector from client."""
|
||||
self.client = client
|
||||
|
||||
self._all_tables = set()
|
||||
@@ -49,7 +52,8 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
) -> "ClickhouseConnector":
|
||||
"""Create a new ClickhouseConnector from host, port, user, pwd, db_name."""
|
||||
import clickhouse_connect
|
||||
from clickhouse_connect.driver import httputil
|
||||
|
||||
@@ -70,11 +74,6 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
cls.client = client
|
||||
return cls(client, **kwargs)
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of dialect to use."""
|
||||
pass
|
||||
|
||||
def get_table_names(self):
|
||||
"""Get all table names."""
|
||||
session = self.client
|
||||
@@ -85,6 +84,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_indexes(self, table_name: str) -> List[Dict]:
|
||||
"""Get table indexes about specified table.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
Returns:
|
||||
@@ -93,10 +93,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT name AS table, primary_key, from system.tables where database ='{self.client.database}' and table = '{table_name}'
|
||||
SELECT name AS table, primary_key, from system.tables where
|
||||
database ='{self.client.database}' and table = '{table_name}'
|
||||
"""
|
||||
with session.query_row_block_stream(_query_sql) as stream:
|
||||
indexes = [block for block in stream]
|
||||
indexes = [block for block in stream] # noqa
|
||||
return [
|
||||
{"name": "primary_key", "column_names": column_names.split(",")}
|
||||
for table, column_names in indexes[0]
|
||||
@@ -104,6 +105,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Get table info."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
@@ -117,7 +119,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
demonstrated in the paper.
|
||||
"""
|
||||
# TODO:
|
||||
pass
|
||||
return ""
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
@@ -133,11 +135,14 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_columns(self, table_name: str) -> List[Dict]:
|
||||
"""Get columns.
|
||||
|
||||
Args:
|
||||
table_name (str): str
|
||||
Returns:
|
||||
columns: List[Dict], which contains name: str, type: str, default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '', 'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
List[Dict], which contains name: str, type: str,
|
||||
default_expression: str, is_in_primary_key: bool, comment: str
|
||||
eg:[{'name': 'id', 'type': 'UInt64', 'default_expression': '',
|
||||
'is_in_primary_key': True, 'comment': 'id'}, ...]
|
||||
"""
|
||||
fields = self.get_fields(table_name)
|
||||
return [
|
||||
@@ -150,18 +155,21 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'
|
||||
SELECT name, type, default_expression, is_in_primary_key, comment
|
||||
from system.columns where table='{table_name}'
|
||||
""".format(
|
||||
table_name
|
||||
)
|
||||
with session.query_row_block_stream(_query_sql) as stream:
|
||||
fields = [block for block in stream]
|
||||
fields = [block for block in stream] # noqa
|
||||
return fields
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grants."""
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
@@ -169,9 +177,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
|
||||
def get_database_names(self):
|
||||
"""Get database names."""
|
||||
session = self.client
|
||||
|
||||
with session.command("SHOW DATABASES") as stream:
|
||||
@@ -184,12 +194,10 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
]
|
||||
return databases
|
||||
|
||||
def get_database_names(self):
|
||||
return self.get_database_list()
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute sql command."""
|
||||
# TODO need to be implemented
|
||||
print("SQL:" + command)
|
||||
logger.info("SQL:" + command)
|
||||
if not command or len(command) < 0:
|
||||
return []
|
||||
_, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
@@ -199,10 +207,12 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
else:
|
||||
self._write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
logger.info(f"write result query:{select_sql}")
|
||||
return self._query(select_sql)
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
logger.info(
|
||||
"DDL execution determines whether to enable through configuration "
|
||||
)
|
||||
|
||||
cursor = self.client.command(command)
|
||||
|
||||
@@ -212,7 +222,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
logger.info("DDL Result:" + str(result))
|
||||
if not result:
|
||||
# return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return self.get_simple_fields(table_name)
|
||||
@@ -225,13 +235,16 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return self._query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def get_current_db_name(self):
|
||||
"""Get current database name."""
|
||||
return self.client.database
|
||||
|
||||
def get_table_comments(self, db_name: str):
|
||||
"""Get table comments."""
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
|
||||
SELECT table, comment FROM system.tables WHERE database = '{db_name}'
|
||||
""".format(
|
||||
db_name
|
||||
)
|
||||
|
||||
@@ -241,6 +254,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
def get_table_comment(self, table_name: str) -> Dict:
|
||||
"""Get table comment.
|
||||
|
||||
Args:
|
||||
table_name (str): table name
|
||||
Returns:
|
||||
@@ -249,7 +263,9 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
session = self.client
|
||||
|
||||
_query_sql = f"""
|
||||
SELECT table, comment FROM system.tables WHERE database = '{self.client.database}'and table = '{table_name}'""".format(
|
||||
SELECT table, comment FROM system.tables WHERE
|
||||
database = '{self.client.database}'and table = '{table_name}'
|
||||
""".format(
|
||||
self.client.database
|
||||
)
|
||||
|
||||
@@ -258,9 +274,11 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return [{"text": comment} for table_name, comment in table_comments][0]
|
||||
|
||||
def get_column_comments(self, db_name, table_name):
|
||||
"""Get column comments."""
|
||||
session = self.client
|
||||
_query_sql = f"""
|
||||
select name column, comment from system.columns where database='{db_name}' and table='{table_name}'
|
||||
select name column, comment from system.columns where database='{db_name}'
|
||||
and table='{table_name}'
|
||||
""".format(
|
||||
db_name, table_name
|
||||
)
|
||||
@@ -270,10 +288,13 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return column_comments
|
||||
|
||||
def table_simple_info(self):
|
||||
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
|
||||
"""Get table simple info."""
|
||||
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray
|
||||
# instead; and quotes need to be escaped
|
||||
|
||||
_sql = f"""
|
||||
SELECT concat(TABLE_NAME, '(', arrayStringConcat(groupArray(column_name), '-'), ')') AS schema_info
|
||||
SELECT concat(TABLE_NAME, '(', arrayStringConcat(
|
||||
groupArray(column_name), '-'), ')') AS schema_info
|
||||
FROM INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE table_schema = '{self.get_current_db_name()}'
|
||||
GROUP BY TABLE_NAME
|
||||
@@ -282,18 +303,18 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
return [row[0] for block in stream for row in block]
|
||||
|
||||
def _write(self, write_sql: str):
|
||||
"""write data
|
||||
"""Execute write sql.
|
||||
|
||||
Args:
|
||||
write_sql (str): sql string
|
||||
"""
|
||||
# TODO need to be implemented
|
||||
print(f"Write[{write_sql}]")
|
||||
logger.info(f"Write[{write_sql}]")
|
||||
result = self.client.command(write_sql)
|
||||
print(f"SQL[{write_sql}], result:{result.written_rows}")
|
||||
logger.info(f"SQL[{write_sql}], result:{result.written_rows}")
|
||||
|
||||
def _query(self, query: str, fetch: str = "all"):
|
||||
"""Query data from clickhouse
|
||||
"""Query data from clickhouse.
|
||||
|
||||
Args:
|
||||
query (str): sql string
|
||||
@@ -306,7 +327,7 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
_type_: List<Result>
|
||||
"""
|
||||
# TODO need to be implemented
|
||||
print(f"Query[{query}]")
|
||||
logger.info(f"Query[{query}]")
|
||||
|
||||
if not query:
|
||||
return []
|
||||
@@ -334,11 +355,13 @@ class ClickhouseConnect(RDBMSDatabase):
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
|
||||
logger.info(
|
||||
f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}"
|
||||
)
|
||||
return parsed, ttype, sql_type, table_name
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
"""Read table information from database"""
|
||||
"""Read table information from database."""
|
||||
# TODO Use a background thread to refresh periodically
|
||||
|
||||
# SQL will raise error with schema
|
||||
|
Reference in New Issue
Block a user