feat:add clickouse connection

This commit is contained in:
aries_ckt 2023-08-22 13:25:23 +08:00
parent bf3e52aa32
commit 5ed35acce8
4 changed files with 143 additions and 1 deletions

View File

@ -29,6 +29,7 @@ class DBType(Enum):
Oracle = DbInfo("oracle")
MSSQL = DbInfo("mssql")
Postgresql = DbInfo("postgresql")
Clickhouse = DbInfo("clickhouse")
def value(self):
return self._value_.name

View File

@ -11,6 +11,7 @@ from pilot.connections.rdbms.conn_duckdb import DuckDbConnect
from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig

View File

@ -0,0 +1,140 @@
import re
from typing import Optional, Any
from sqlalchemy import text
from pilot.connections.rdbms.base import RDBMSDatabase
class ClickhouseConnect(RDBMSDatabase):
"""Connect Clickhouse Database fetch MetaData
Args:
Usage:
"""
db_type: str = "clickhouse"
driver: str = "clickhouse"
db_dialect: str = "clickhouse"
@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ user
+ ":"
+ pwd
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
)
return cls.from_uri(db_url, engine_args, **kwargs)
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return """"""
def get_show_create_table(self, table_name):
"""Get table show create table about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
ans = cursor.fetchall()
ans = ans[0][0]
ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
ans = re.sub(
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
)
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
return ans
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'".format(
table_name
)
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_users(self):
return []
def get_grants(self):
return []
def get_collation(self):
"""Get collation."""
return "UTF-8"
def get_charset(self):
return "UTF-8"
def get_database_list(self):
return []
def get_database_names(self):
return []
def get_table_comments(self, db_name):
session = self._db_sessions()
cursor = session.execute(
text(
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
db_name
)
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
# def get_table_comments(self, db_name=None):
# cursor = self.session.execute(
# text(
# f"""
# SELECT name, sql FROM sqlite_master WHERE type='table'
# """
# )
# )
# table_comments = cursor.fetchall()
# return [
# (table_comment[0], table_comment[1]) for table_comment in table_comments
# ]
#
# def table_simple_info(self) -> Iterable[str]:
# _tables_sql = f"""
# SELECT name FROM sqlite_master WHERE type='table'
# """
# cursor = self.session.execute(text(_tables_sql))
# tables_results = cursor.fetchall()
# results = []
# for row in tables_results:
# table_name = row[0]
# _sql = f"""
# PRAGMA table_info({table_name})
# """
# cursor_colums = self.session.execute(text(_sql))
# colum_results = cursor_colums.fetchall()
# table_colums = []
# for row_col in colum_results:
# field_info = list(row_col)
# table_colums.append(field_info[1])
#
# results.append(f"{table_name}({','.join(table_colums)});")
# return results

View File

@ -118,7 +118,7 @@ async def db_connect_delete(db_name: str = None):
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb, DBType.SQLite]
support_types = [DBType.Mysql, DBType.MSSQL, DBType.DuckDb, DBType.SQLite, DBType.Clickhouse]
db_type_infos = []
for type in support_types:
db_type_infos.append(