mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-15 14:11:14 +00:00
refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
1
dbgpt/datasource/__init__.py
Normal file
1
dbgpt/datasource/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .manages.connect_config_db import ConnectConfigEntity, ConnectConfigDao
|
59
dbgpt/datasource/base.py
Normal file
59
dbgpt/datasource/base.py
Normal file
@@ -0,0 +1,59 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
"""We need to design a base class. That other connector can Write with this"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Iterable, List, Optional
|
||||
|
||||
|
||||
class BaseConnect(ABC):
|
||||
def get_connect(self, db_name: str):
|
||||
pass
|
||||
|
||||
def get_table_names(self) -> Iterable[str]:
|
||||
pass
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
pass
|
||||
|
||||
def get_index_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
pass
|
||||
|
||||
def get_example_data(self, table: str, count: int = 3):
|
||||
pass
|
||||
|
||||
def get_database_list(self):
|
||||
pass
|
||||
|
||||
def get_database_names(self):
|
||||
pass
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
pass
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
pass
|
||||
|
||||
def run_to_df(self, command: str, fetch: str = "all"):
|
||||
pass
|
||||
|
||||
def get_users(self):
|
||||
pass
|
||||
|
||||
def get_grants(self):
|
||||
pass
|
||||
|
||||
def get_collation(self):
|
||||
pass
|
||||
|
||||
def get_charset(self):
|
||||
pass
|
||||
|
||||
def get_fields(self, table_name):
|
||||
pass
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
pass
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
pass
|
119
dbgpt/datasource/conn_spark.py
Normal file
119
dbgpt/datasource/conn_spark.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
|
||||
|
||||
class SparkConnect(BaseConnect):
|
||||
"""
|
||||
Spark Connect supports operating on a variety of data sources through the DataFrame interface.
|
||||
A DataFrame can be operated on using relational transformations and can also be used to create a temporary view.
|
||||
Registering a DataFrame as a temporary view allows you to run SQL queries over its data.
|
||||
Datasource now support parquet, jdbc, orc, libsvm, csv, text, json.
|
||||
"""
|
||||
|
||||
"""db type"""
|
||||
db_type: str = "spark"
|
||||
"""db driver"""
|
||||
driver: str = "spark"
|
||||
"""db dialect"""
|
||||
dialect: str = "sparksql"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
spark_session: Optional = None,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the Spark DataFrame from Datasource path
|
||||
return: Spark DataFrame
|
||||
"""
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
self.spark_session = (
|
||||
spark_session or SparkSession.builder.appName("dbgpt_spark").getOrCreate()
|
||||
)
|
||||
self.path = file_path
|
||||
self.table_name = "temp"
|
||||
self.df = self.create_df(self.path)
|
||||
|
||||
@classmethod
|
||||
def from_file_path(
|
||||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
):
|
||||
try:
|
||||
return cls(file_path=file_path, engine_args=engine_args)
|
||||
|
||||
except Exception as e:
|
||||
print("load spark datasource error" + str(e))
|
||||
|
||||
def create_df(self, path):
|
||||
"""Create a Spark DataFrame from Datasource path(now support parquet, jdbc, orc, libsvm, csv, text, json.).
|
||||
return: Spark DataFrame
|
||||
reference:https://spark.apache.org/docs/latest/sql-data-sources-load-save-functions.html
|
||||
"""
|
||||
extension = (
|
||||
"text" if path.rsplit(".", 1)[-1] == "txt" else path.rsplit(".", 1)[-1]
|
||||
)
|
||||
return self.spark_session.read.load(
|
||||
path, format=extension, inferSchema="true", header="true"
|
||||
)
|
||||
|
||||
def run(self, sql):
|
||||
print(f"spark sql to run is {sql}")
|
||||
self.df.createOrReplaceTempView(self.table_name)
|
||||
df = self.spark_session.sql(sql)
|
||||
first_row = df.first()
|
||||
rows = [first_row.asDict().keys()]
|
||||
for row in df.collect():
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def query_ex(self, sql):
|
||||
rows = self.run(sql)
|
||||
field_names = rows[0]
|
||||
return field_names, rows
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
return ""
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
|
||||
return "ans"
|
||||
|
||||
def get_fields(self):
|
||||
"""Get column meta about dataframe."""
|
||||
return ",".join([f"({name}: {dtype})" for name, dtype in self.df.dtypes])
|
||||
|
||||
def get_users(self):
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
return "UTF-8"
|
||||
|
||||
def get_db_list(self):
|
||||
return ["default"]
|
||||
|
||||
def get_db_names(self):
|
||||
return ["default"]
|
||||
|
||||
def get_database_list(self):
|
||||
return []
|
||||
|
||||
def get_database_names(self):
|
||||
return []
|
||||
|
||||
def table_simple_info(self):
|
||||
return f"{self.table_name}{self.get_fields()}"
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
return ""
|
17
dbgpt/datasource/db_conn_info.py
Normal file
17
dbgpt/datasource/db_conn_info.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DBConfig(BaseModel):
|
||||
db_type: str
|
||||
db_name: str
|
||||
file_path: str = ""
|
||||
db_host: str = ""
|
||||
db_port: int = 0
|
||||
db_user: str = ""
|
||||
db_pwd: str = ""
|
||||
comment: str = ""
|
||||
|
||||
|
||||
class DbTypeInfo(BaseModel):
|
||||
db_type: str
|
||||
is_file_db: bool = False
|
0
dbgpt/datasource/manages/__init__.py
Normal file
0
dbgpt/datasource/manages/__init__.py
Normal file
244
dbgpt/datasource/manages/connect_config_db.py
Normal file
244
dbgpt/datasource/manages/connect_config_db.py
Normal file
@@ -0,0 +1,244 @@
|
||||
from sqlalchemy import Column, Integer, String, Index, Text, text
|
||||
from sqlalchemy import UniqueConstraint
|
||||
|
||||
from dbgpt.storage.metadata import BaseDao
|
||||
from dbgpt.storage.metadata.meta_data import (
|
||||
Base,
|
||||
engine,
|
||||
session,
|
||||
META_DATA_DATABASE,
|
||||
)
|
||||
|
||||
|
||||
class ConnectConfigEntity(Base):
|
||||
"""db connect config entity"""
|
||||
|
||||
__tablename__ = "connect_config"
|
||||
id = Column(
|
||||
Integer, primary_key=True, autoincrement=True, comment="autoincrement id"
|
||||
)
|
||||
|
||||
db_type = Column(String(255), nullable=False, comment="db type")
|
||||
db_name = Column(String(255), nullable=False, comment="db name")
|
||||
db_path = Column(String(255), nullable=True, comment="file db path")
|
||||
db_host = Column(String(255), nullable=True, comment="db connect host(not file db)")
|
||||
db_port = Column(String(255), nullable=True, comment="db connect port(not file db)")
|
||||
db_user = Column(String(255), nullable=True, comment="db user")
|
||||
db_pwd = Column(String(255), nullable=True, comment="db password")
|
||||
comment = Column(Text, nullable=True, comment="db comment")
|
||||
sys_code = Column(String(128), index=True, nullable=True, comment="System code")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("db_name", name="uk_db"),
|
||||
Index("idx_q_db_type", "db_type"),
|
||||
{"mysql_charset": "utf8mb4", "mysql_collate": "utf8mb4_unicode_ci"},
|
||||
)
|
||||
|
||||
|
||||
class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
|
||||
"""db connect config dao"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
database=META_DATA_DATABASE,
|
||||
orm_base=Base,
|
||||
db_engine=engine,
|
||||
session=session,
|
||||
)
|
||||
|
||||
def update(self, entity: ConnectConfigEntity):
|
||||
"""update db connect info"""
|
||||
session = self.get_session()
|
||||
try:
|
||||
updated = session.merge(entity)
|
||||
session.commit()
|
||||
return updated.id
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def delete(self, db_name: int):
|
||||
""" "delete db connect info"""
|
||||
session = self.get_session()
|
||||
if db_name is None:
|
||||
raise Exception("db_name is None")
|
||||
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
db_connect.delete()
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
def get_by_names(self, db_name: str) -> ConnectConfigEntity:
|
||||
"""get db connect info by name"""
|
||||
session = self.get_session()
|
||||
db_connect = session.query(ConnectConfigEntity)
|
||||
db_connect = db_connect.filter(ConnectConfigEntity.db_name == db_name)
|
||||
result = db_connect.first()
|
||||
session.close()
|
||||
return result
|
||||
|
||||
def add_url_db(
|
||||
self,
|
||||
db_name,
|
||||
db_type,
|
||||
db_host: str,
|
||||
db_port: int,
|
||||
db_user: str,
|
||||
db_pwd: str,
|
||||
comment: str = "",
|
||||
):
|
||||
"""
|
||||
add db connect info
|
||||
Args:
|
||||
db_name: db name
|
||||
db_type: db type
|
||||
db_host: db host
|
||||
db_port: db port
|
||||
db_user: db user
|
||||
db_pwd: db password
|
||||
comment: comment
|
||||
"""
|
||||
try:
|
||||
session = self.get_session()
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
insert_statement = text(
|
||||
"""
|
||||
INSERT INTO connect_config (
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
|
||||
) VALUES (
|
||||
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
params = {
|
||||
"db_name": db_name,
|
||||
"db_type": db_type,
|
||||
"db_path": "",
|
||||
"db_host": db_host,
|
||||
"db_port": db_port,
|
||||
"db_user": db_user,
|
||||
"db_pwd": db_pwd,
|
||||
"comment": comment,
|
||||
}
|
||||
session.execute(insert_statement, params)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("add db connect info error!" + str(e))
|
||||
|
||||
def update_db_info(
|
||||
self,
|
||||
db_name,
|
||||
db_type,
|
||||
db_path: str = "",
|
||||
db_host: str = "",
|
||||
db_port: int = 0,
|
||||
db_user: str = "",
|
||||
db_pwd: str = "",
|
||||
comment: str = "",
|
||||
):
|
||||
"""update db connect info"""
|
||||
old_db_conf = self.get_db_config(db_name)
|
||||
if old_db_conf:
|
||||
try:
|
||||
session = self.get_session()
|
||||
if not db_path:
|
||||
update_statement = text(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
|
||||
)
|
||||
else:
|
||||
update_statement = text(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
|
||||
)
|
||||
session.execute(update_statement)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("edit db connect info error!" + str(e))
|
||||
return True
|
||||
raise ValueError(f"{db_name} not have config info!")
|
||||
|
||||
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
|
||||
"""add file db connect info"""
|
||||
try:
|
||||
session = self.get_session()
|
||||
insert_statement = text(
|
||||
"""
|
||||
INSERT INTO connect_config(
|
||||
db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment
|
||||
) VALUES (
|
||||
:db_name, :db_type, :db_path, :db_host, :db_port, :db_user, :db_pwd, :comment
|
||||
)
|
||||
"""
|
||||
)
|
||||
params = {
|
||||
"db_name": db_name,
|
||||
"db_type": db_type,
|
||||
"db_path": db_path,
|
||||
"db_host": "",
|
||||
"db_port": 0,
|
||||
"db_user": "",
|
||||
"db_pwd": "",
|
||||
"comment": comment,
|
||||
}
|
||||
|
||||
session.execute(insert_statement, params)
|
||||
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
print("add db connect info error!" + str(e))
|
||||
|
||||
def get_db_config(self, db_name):
|
||||
"""get db config by name"""
|
||||
session = self.get_session()
|
||||
if db_name:
|
||||
select_statement = text(
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
connect_config
|
||||
WHERE
|
||||
db_name = :db_name
|
||||
"""
|
||||
)
|
||||
params = {"db_name": db_name}
|
||||
result = session.execute(select_statement, params)
|
||||
|
||||
else:
|
||||
raise ValueError("Cannot get database by name" + db_name)
|
||||
|
||||
fields = [field[0] for field in result.cursor.description]
|
||||
row_dict = {}
|
||||
row_1 = list(result.cursor.fetchall()[0])
|
||||
for i, field in enumerate(fields):
|
||||
row_dict[field] = row_1[i]
|
||||
return row_dict
|
||||
|
||||
def get_db_list(self):
|
||||
"""get db list"""
|
||||
session = self.get_session()
|
||||
result = session.execute(text("SELECT * FROM connect_config"))
|
||||
|
||||
fields = [field[0] for field in result.cursor.description]
|
||||
data = []
|
||||
for row in result.cursor.fetchall():
|
||||
row_dict = {}
|
||||
for i, field in enumerate(fields):
|
||||
row_dict[field] = row[i]
|
||||
data.append(row_dict)
|
||||
return data
|
||||
|
||||
def delete_db(self, db_name):
|
||||
"""delete db connect info"""
|
||||
session = self.get_session()
|
||||
delete_statement = text("""DELETE FROM connect_config where db_name=:db_name""")
|
||||
params = {"db_name": db_name}
|
||||
session.execute(delete_statement, params)
|
||||
session.commit()
|
||||
session.close()
|
||||
return True
|
149
dbgpt/datasource/manages/connect_storage_duckdb.py
Normal file
149
dbgpt/datasource/manages/connect_storage_duckdb.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import duckdb
|
||||
from dbgpt.configs.model_config import PILOT_PATH
|
||||
|
||||
default_db_path = os.path.join(PILOT_PATH, "message")
|
||||
|
||||
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/connect_config.db")
|
||||
table_name = "connect_config"
|
||||
|
||||
|
||||
class DuckdbConnectConfig:
|
||||
def __init__(self):
|
||||
os.makedirs(default_db_path, exist_ok=True)
|
||||
self.connect = duckdb.connect(duckdb_path)
|
||||
self.__init_config_tables()
|
||||
|
||||
def __init_config_tables(self):
|
||||
# check config table
|
||||
result = self.connect.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name=?", [table_name]
|
||||
).fetchall()
|
||||
|
||||
if not result:
|
||||
# create config table
|
||||
self.connect.execute(
|
||||
"CREATE TABLE connect_config (id integer primary key, db_name VARCHAR(100) UNIQUE, db_type VARCHAR(50), db_path VARCHAR(255) NULL, db_host VARCHAR(255) NULL, db_port INTEGER NULL, db_user VARCHAR(255) NULL, db_pwd VARCHAR(255) NULL, comment TEXT NULL)"
|
||||
)
|
||||
self.connect.execute("CREATE SEQUENCE seq_id START 1;")
|
||||
|
||||
def add_url_db(
|
||||
self,
|
||||
db_name,
|
||||
db_type,
|
||||
db_host: str,
|
||||
db_port: int,
|
||||
db_user: str,
|
||||
db_pwd: str,
|
||||
comment: str = "",
|
||||
):
|
||||
try:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
|
||||
[db_name, db_type, "", db_host, db_port, db_user, db_pwd, comment],
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
except Exception as e:
|
||||
print("add db connect info error1!" + str(e))
|
||||
|
||||
def update_db_info(
|
||||
self,
|
||||
db_name,
|
||||
db_type,
|
||||
db_path: str = "",
|
||||
db_host: str = "",
|
||||
db_port: int = 0,
|
||||
db_user: str = "",
|
||||
db_pwd: str = "",
|
||||
comment: str = "",
|
||||
):
|
||||
old_db_conf = self.get_db_config(db_name)
|
||||
if old_db_conf:
|
||||
try:
|
||||
cursor = self.connect.cursor()
|
||||
if not db_path:
|
||||
cursor.execute(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_host='{db_host}', db_port={db_port}, db_user='{db_user}', db_pwd='{db_pwd}', comment='{comment}' where db_name='{db_name}'"
|
||||
)
|
||||
else:
|
||||
cursor.execute(
|
||||
f"UPDATE connect_config set db_type='{db_type}', db_path='{db_path}', comment='{comment}' where db_name='{db_name}'"
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
except Exception as e:
|
||||
print("edit db connect info error2!" + str(e))
|
||||
return True
|
||||
raise ValueError(f"{db_name} not have config info!")
|
||||
|
||||
def get_file_db_name(self, path):
|
||||
try:
|
||||
conn = duckdb.connect(path)
|
||||
result = conn.execute("SELECT current_database()").fetchone()[0]
|
||||
return result
|
||||
except Exception as e:
|
||||
raise "Unusable duckdb database path:" + path
|
||||
|
||||
def add_file_db(self, db_name, db_type, db_path: str, comment: str = ""):
|
||||
try:
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute(
|
||||
"INSERT INTO connect_config(id, db_name, db_type, db_path, db_host, db_port, db_user, db_pwd, comment)VALUES(nextval('seq_id'),?,?,?,?,?,?,?,?)",
|
||||
[db_name, db_type, db_path, "", 0, "", "", comment],
|
||||
)
|
||||
cursor.commit()
|
||||
self.connect.commit()
|
||||
except Exception as e:
|
||||
print("add db connect info error2!" + str(e))
|
||||
|
||||
def delete_db(self, db_name):
|
||||
cursor = self.connect.cursor()
|
||||
cursor.execute("DELETE FROM connect_config where db_name=?", [db_name])
|
||||
cursor.commit()
|
||||
return True
|
||||
|
||||
def get_db_config(self, db_name):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
if db_name:
|
||||
cursor.execute(
|
||||
"SELECT * FROM connect_config where db_name=? ", [db_name]
|
||||
)
|
||||
else:
|
||||
raise ValueError("Cannot get database by name" + db_name)
|
||||
|
||||
fields = [field[0] for field in cursor.description]
|
||||
row_dict = {}
|
||||
row_1 = list(cursor.fetchall()[0])
|
||||
for i, field in enumerate(fields):
|
||||
row_dict[field] = row_1[i]
|
||||
return row_dict
|
||||
return None
|
||||
|
||||
def get_db_list(self):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
cursor.execute("SELECT * FROM connect_config ")
|
||||
|
||||
fields = [field[0] for field in cursor.description]
|
||||
data = []
|
||||
for row in cursor.fetchall():
|
||||
row_dict = {}
|
||||
for i, field in enumerate(fields):
|
||||
row_dict[field] = row[i]
|
||||
data.append(row_dict)
|
||||
return data
|
||||
|
||||
return []
|
||||
|
||||
def get_db_names(self):
|
||||
if os.path.isfile(duckdb_path):
|
||||
cursor = duckdb.connect(duckdb_path).cursor()
|
||||
cursor.execute("SELECT db_name FROM connect_config ")
|
||||
data = []
|
||||
for row in cursor.fetchall():
|
||||
data.append(row[0])
|
||||
return data
|
||||
return []
|
150
dbgpt/datasource/manages/connection_manager.py
Normal file
150
dbgpt/datasource/manages/connection_manager.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from dbgpt.datasource import ConnectConfigDao
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.component import SystemApp, ComponentType
|
||||
from dbgpt.util.executor_utils import ExecutorFactory
|
||||
|
||||
from dbgpt.datasource.db_conn_info import DBConfig
|
||||
from dbgpt.rag.summary.db_summary_client import DBSummaryClient
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
from dbgpt.datasource.rdbms.conn_mysql import MySQLConnect
|
||||
from dbgpt.datasource.rdbms.conn_duckdb import DuckDbConnect
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
|
||||
from dbgpt.datasource.rdbms.conn_mssql import MSSQLConnect
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.conn_clickhouse import ClickhouseConnect
|
||||
from dbgpt.datasource.rdbms.conn_postgresql import PostgreSQLDatabase
|
||||
from dbgpt.datasource.rdbms.conn_starrocks import StarRocksConnect
|
||||
from dbgpt.datasource.conn_spark import SparkConnect
|
||||
from dbgpt.datasource.rdbms.conn_doris import DorisConnect
|
||||
|
||||
|
||||
class ConnectManager:
|
||||
"""db connect manager"""
|
||||
|
||||
def get_all_subclasses(self, cls):
|
||||
subclasses = cls.__subclasses__()
|
||||
for subclass in subclasses:
|
||||
subclasses += self.get_all_subclasses(subclass)
|
||||
return subclasses
|
||||
|
||||
def get_all_completed_types(self):
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
support_types = []
|
||||
for cls in chat_classes:
|
||||
if cls.db_type:
|
||||
support_types.append(DBType.of_db_type(cls.db_type))
|
||||
return support_types
|
||||
|
||||
def get_cls_by_dbtype(self, db_type):
|
||||
chat_classes = self.get_all_subclasses(BaseConnect)
|
||||
result = None
|
||||
for cls in chat_classes:
|
||||
if cls.db_type == db_type:
|
||||
result = cls
|
||||
if not result:
|
||||
raise ValueError("Unsupported Db Type!" + db_type)
|
||||
return result
|
||||
|
||||
def __init__(self, system_app: SystemApp):
|
||||
"""metadata database management initialization"""
|
||||
# self.storage = DuckdbConnectConfig()
|
||||
self.storage = ConnectConfigDao()
|
||||
self.system_app = system_app
|
||||
self.db_summary_client = DBSummaryClient(system_app)
|
||||
|
||||
def get_connect(self, db_name):
|
||||
db_config = self.storage.get_db_config(db_name)
|
||||
db_type = DBType.of_db_type(db_config.get("db_type"))
|
||||
connect_instance = self.get_cls_by_dbtype(db_type.value())
|
||||
if db_type.is_file_db():
|
||||
db_path = db_config.get("db_path")
|
||||
return connect_instance.from_file_path(db_path)
|
||||
else:
|
||||
db_host = db_config.get("db_host")
|
||||
db_port = db_config.get("db_port")
|
||||
db_user = db_config.get("db_user")
|
||||
db_pwd = db_config.get("db_pwd")
|
||||
return connect_instance.from_uri_db(
|
||||
host=db_host, port=db_port, user=db_user, pwd=db_pwd, db_name=db_name
|
||||
)
|
||||
|
||||
def test_connect(self, db_info: DBConfig):
|
||||
try:
|
||||
db_type = DBType.of_db_type(db_info.db_type)
|
||||
connect_instance = self.get_cls_by_dbtype(db_type.value())
|
||||
if db_type.is_file_db():
|
||||
db_path = db_info.file_path
|
||||
return connect_instance.from_file_path(db_path)
|
||||
else:
|
||||
db_name = db_info.db_name
|
||||
db_host = db_info.db_host
|
||||
db_port = db_info.db_port
|
||||
db_user = db_info.db_user
|
||||
db_pwd = db_info.db_pwd
|
||||
return connect_instance.from_uri_db(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
user=db_user,
|
||||
pwd=db_pwd,
|
||||
db_name=db_name,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"{db_info.db_name} Test connect Failure!{str(e)}")
|
||||
raise ValueError(f"{db_info.db_name} Test connect Failure!{str(e)}")
|
||||
|
||||
def get_db_list(self):
|
||||
return self.storage.get_db_list()
|
||||
|
||||
def get_db_names(self):
|
||||
return self.storage.get_by_name()
|
||||
|
||||
def delete_db(self, db_name: str):
|
||||
return self.storage.delete_db(db_name)
|
||||
|
||||
def edit_db(self, db_info: DBConfig):
|
||||
return self.storage.update_db_info(
|
||||
db_info.db_name,
|
||||
db_info.db_type,
|
||||
db_info.file_path,
|
||||
db_info.db_host,
|
||||
db_info.db_port,
|
||||
db_info.db_user,
|
||||
db_info.db_pwd,
|
||||
db_info.comment,
|
||||
)
|
||||
|
||||
async def async_db_summary_embedding(self, db_name, db_type):
|
||||
# 在这里执行需要异步运行的代码
|
||||
self.db_summary_client.db_summary_embedding(db_name, db_type)
|
||||
|
||||
def add_db(self, db_info: DBConfig):
|
||||
print(f"add_db:{db_info.__dict__}")
|
||||
try:
|
||||
db_type = DBType.of_db_type(db_info.db_type)
|
||||
if db_type.is_file_db():
|
||||
self.storage.add_file_db(
|
||||
db_info.db_name, db_info.db_type, db_info.file_path
|
||||
)
|
||||
else:
|
||||
self.storage.add_url_db(
|
||||
db_info.db_name,
|
||||
db_info.db_type,
|
||||
db_info.db_host,
|
||||
db_info.db_port,
|
||||
db_info.db_user,
|
||||
db_info.db_pwd,
|
||||
db_info.comment,
|
||||
)
|
||||
# async embedding
|
||||
executor = self.system_app.get_component(
|
||||
ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
|
||||
).create()
|
||||
executor.submit(
|
||||
self.db_summary_client.db_summary_embedding,
|
||||
db_info.db_name,
|
||||
db_info.db_type,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError("Add db connect info error!" + str(e))
|
||||
|
||||
return True
|
0
dbgpt/datasource/nosql/__init__.py
Normal file
0
dbgpt/datasource/nosql/__init__.py
Normal file
0
dbgpt/datasource/rdbms/__init__.py
Normal file
0
dbgpt/datasource/rdbms/__init__.py
Normal file
84
dbgpt/datasource/rdbms/_base_dao.py
Normal file
84
dbgpt/datasource/rdbms/_base_dao.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import logging
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from dbgpt._private.config import Config
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
CFG = Config()
|
||||
|
||||
|
||||
class BaseDao:
|
||||
def __init__(
|
||||
self, orm_base=None, database: str = None, create_not_exist_table: bool = False
|
||||
) -> None:
|
||||
"""BaseDAO, If the current database is a file database and create_not_exist_table=True, we will automatically create a table that does not exist"""
|
||||
self._orm_base = orm_base
|
||||
self._database = database
|
||||
self._create_not_exist_table = create_not_exist_table
|
||||
|
||||
self._db_engine = None
|
||||
self._session = None
|
||||
self._connection = None
|
||||
|
||||
@property
|
||||
def db_engine(self):
|
||||
if not self._db_engine:
|
||||
# lazy loading
|
||||
db_engine, connection = _get_db_engine(
|
||||
self._orm_base, self._database, self._create_not_exist_table
|
||||
)
|
||||
self._db_engine = db_engine
|
||||
self._connection = connection
|
||||
return self._db_engine
|
||||
|
||||
@property
|
||||
def Session(self):
|
||||
if not self._session:
|
||||
self._session = sessionmaker(bind=self.db_engine)
|
||||
return self._session
|
||||
|
||||
|
||||
def _get_db_engine(
|
||||
orm_base=None, database: str = None, create_not_exist_table: bool = False
|
||||
):
|
||||
db_engine = None
|
||||
connection: RDBMSDatabase = None
|
||||
|
||||
db_type = DBType.of_db_type(CFG.LOCAL_DB_TYPE)
|
||||
if db_type is None or db_type == DBType.Mysql:
|
||||
# default database
|
||||
db_engine = create_engine(
|
||||
f"mysql+pymysql://{CFG.LOCAL_DB_USER}:{CFG.LOCAL_DB_PASSWORD}@{CFG.LOCAL_DB_HOST}:{CFG.LOCAL_DB_PORT}/{database}",
|
||||
echo=True,
|
||||
)
|
||||
else:
|
||||
db_namager = CFG.LOCAL_DB_MANAGE
|
||||
if not db_namager:
|
||||
raise Exception(
|
||||
"LOCAL_DB_MANAGE is not initialized, please check the system configuration"
|
||||
)
|
||||
if db_type.is_file_db():
|
||||
db_path = CFG.LOCAL_DB_PATH
|
||||
if db_path is None or db_path == "":
|
||||
raise ValueError(
|
||||
"You LOCAL_DB_TYPE is file db, but LOCAL_DB_PATH is not configured, please configure LOCAL_DB_PATH in you .env file"
|
||||
)
|
||||
_, database = db_namager._parse_file_db_info(db_type.value(), db_path)
|
||||
logger.info(
|
||||
f"Current DAO database is file database, db_type: {db_type.value()}, db_path: {db_path}, db_name: {database}"
|
||||
)
|
||||
logger.info(f"Get DAO database connection with database name {database}")
|
||||
connection: RDBMSDatabase = db_namager.get_connect(database)
|
||||
if not isinstance(connection, RDBMSDatabase):
|
||||
raise ValueError(
|
||||
"Currently only supports `RDBMSDatabase` database as the underlying database of BaseDao, please check your database configuration"
|
||||
)
|
||||
db_engine = connection._engine
|
||||
|
||||
if db_type.is_file_db() and orm_base is not None and create_not_exist_table:
|
||||
logger.info("Current database is file database, create not exist table")
|
||||
orm_base.metadata.create_all(db_engine)
|
||||
|
||||
return db_engine, connection
|
549
dbgpt/datasource/rdbms/base.py
Normal file
549
dbgpt/datasource/rdbms/base.py
Normal file
@@ -0,0 +1,549 @@
|
||||
from __future__ import annotations
|
||||
import sqlparse
|
||||
import regex as re
|
||||
import pandas as pd
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
from typing import Any, Iterable, List, Optional
|
||||
import sqlalchemy
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
|
||||
from sqlalchemy.schema import CreateTable
|
||||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||
|
||||
from dbgpt.storage.schema import DBType
|
||||
from dbgpt.datasource.base import BaseConnect
|
||||
from dbgpt._private.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
|
||||
return (
|
||||
f'Name: {index["name"]}, Unique: {index["unique"]},'
|
||||
f' Columns: {str(index["column_names"])}'
|
||||
)
|
||||
|
||||
|
||||
class RDBMSDatabase(BaseConnect):
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
db_type: str = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
):
|
||||
"""Create engine from database URI.
|
||||
Args:
|
||||
- engine: Engine sqlalchemy.engine
|
||||
- schema: Optional[str].
|
||||
- metadata: Optional[MetaData]
|
||||
- ignore_tables: Optional[List[str]]
|
||||
- include_tables: Optional[List[str]]
|
||||
- sample_rows_in_table_info: int default:3,
|
||||
- indexes_in_table_info: bool = False,
|
||||
- custom_table_info: Optional[dict] = None,
|
||||
- view_support: bool = False,
|
||||
"""
|
||||
self._engine = engine
|
||||
self._schema = schema
|
||||
if include_tables and ignore_tables:
|
||||
raise ValueError("Cannot specify both include_tables and ignore_tables")
|
||||
|
||||
self._inspector = inspect(engine)
|
||||
session_factory = sessionmaker(bind=engine)
|
||||
Session_Manages = scoped_session(session_factory)
|
||||
self._db_sessions = Session_Manages
|
||||
self.session = self.get_session()
|
||||
|
||||
self._all_tables = set()
|
||||
self.view_support = False
|
||||
self._usable_tables = set()
|
||||
self._include_tables = set()
|
||||
self._ignore_tables = set()
|
||||
self._custom_table_info = set()
|
||||
self._indexes_in_table_info = set()
|
||||
self._usable_tables = set()
|
||||
self._usable_tables = set()
|
||||
self._sample_rows_in_table_info = set()
|
||||
self._indexes_in_table_info = indexes_in_table_info
|
||||
|
||||
self._metadata = MetaData()
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
|
||||
self._all_tables = self._sync_tables_from_db()
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from uri database.
|
||||
Args:
|
||||
host (str): database host.
|
||||
port (int): database port.
|
||||
user (str): database user.
|
||||
pwd (str): database password.
|
||||
db_name (str): database name.
|
||||
engine_args (Optional[dict]):other engine_args.
|
||||
"""
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
return cls(create_engine(database_uri, **_engine_args), **kwargs)
|
||||
|
||||
@property
|
||||
def dialect(self) -> str:
|
||||
"""Return string representation of dialect to use."""
|
||||
return self._engine.dialect.name
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
"""Read table information from database"""
|
||||
# TODO Use a background thread to refresh periodically
|
||||
|
||||
# SQL will raise error with schema
|
||||
_schema = (
|
||||
None if self.db_type == DBType.SQLite.value() else self._engine.url.database
|
||||
)
|
||||
# including view support by adding the views as well as tables to the all
|
||||
# tables list if view_support is True
|
||||
self._all_tables = set(
|
||||
self._inspector.get_table_names(schema=_schema)
|
||||
+ (
|
||||
self._inspector.get_view_names(schema=_schema)
|
||||
if self.view_support
|
||||
else []
|
||||
)
|
||||
)
|
||||
return self._all_tables
|
||||
|
||||
def get_usable_table_names(self) -> Iterable[str]:
|
||||
"""Get names of tables available."""
|
||||
if self._include_tables:
|
||||
return self._include_tables
|
||||
return self._all_tables - self._ignore_tables
|
||||
|
||||
def get_table_names(self) -> Iterable[str]:
|
||||
"""Get names of tables available."""
|
||||
return self.get_usable_table_names()
|
||||
|
||||
def get_session(self):
|
||||
session = self._db_sessions()
|
||||
|
||||
return session
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
return self.session.execute(text("SELECT DATABASE()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
_sql = f"""
|
||||
select concat(table_name, "(" , group_concat(column_name), ")") as schema_info from information_schema.COLUMNS where table_schema="{self.get_current_db_name()}" group by TABLE_NAME;
|
||||
"""
|
||||
cursor = self.session.execute(text(_sql))
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
@property
|
||||
def table_info(self) -> str:
|
||||
"""Information about all tables in the database."""
|
||||
return self.get_table_info()
|
||||
|
||||
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables.
|
||||
|
||||
Follows best practices as specified in: Rajkumar et al, 2022
|
||||
(https://arxiv.org/abs/2204.00498)
|
||||
|
||||
If `sample_rows_in_table_info`, the specified number of sample rows will be
|
||||
appended to each table description. This can increase performance as
|
||||
demonstrated in the paper.
|
||||
"""
|
||||
all_table_names = self.get_usable_table_names()
|
||||
if table_names is not None:
|
||||
missing_tables = set(table_names).difference(all_table_names)
|
||||
if missing_tables:
|
||||
raise ValueError(f"table_names {missing_tables} not found in database")
|
||||
all_table_names = table_names
|
||||
|
||||
meta_tables = [
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
]
|
||||
|
||||
tables = []
|
||||
for table in meta_tables:
|
||||
if self._custom_table_info and table.name in self._custom_table_info:
|
||||
tables.append(self._custom_table_info[table.name])
|
||||
continue
|
||||
|
||||
# add create table command
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
if has_extra_info:
|
||||
table_info += "\n\n/*"
|
||||
if self._indexes_in_table_info:
|
||||
table_info += f"\n{self._get_table_indexes(table)}\n"
|
||||
if self._sample_rows_in_table_info:
|
||||
table_info += f"\n{self._get_sample_rows(table)}\n"
|
||||
if has_extra_info:
|
||||
table_info += "*/"
|
||||
tables.append(table_info)
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def _get_sample_rows(self, table: Table) -> str:
|
||||
# build the select command
|
||||
command = select(table).limit(self._sample_rows_in_table_info)
|
||||
|
||||
# save the columns in string format
|
||||
columns_str = "\t".join([col.name for col in table.columns])
|
||||
|
||||
try:
|
||||
# get the sample rows
|
||||
with self._engine.connect() as connection:
|
||||
sample_rows_result: CursorResult = connection.execute(command)
|
||||
# shorten values in the sample rows
|
||||
sample_rows = list(
|
||||
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
|
||||
)
|
||||
|
||||
# save the sample rows in string format
|
||||
sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
|
||||
|
||||
# in some dialects when there are no rows in the table a
|
||||
# 'ProgrammingError' is returned
|
||||
except ProgrammingError:
|
||||
sample_rows_str = ""
|
||||
|
||||
return (
|
||||
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
|
||||
f"{columns_str}\n"
|
||||
f"{sample_rows_str}"
|
||||
)
|
||||
|
||||
def _get_table_indexes(self, table: Table) -> str:
|
||||
indexes = self._inspector.get_indexes(table.name)
|
||||
indexes_formatted = "\n".join(map(_format_index, indexes))
|
||||
return f"Table Indexes:\n{indexes_formatted}"
|
||||
|
||||
def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
|
||||
"""Get information about specified tables."""
|
||||
try:
|
||||
return self.get_table_info(table_names)
|
||||
except ValueError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def __write(self, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
db_cache = self._engine.url.database
|
||||
result = self.session.execute(text(write_sql))
|
||||
self.session.commit()
|
||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
self.session.execute(text(f"use `{db_cache}`"))
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def __query(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
Args:
|
||||
session:
|
||||
query:
|
||||
fetch:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
print(f"Query[{query}]")
|
||||
if not query:
|
||||
return []
|
||||
cursor = self.session.execute(text(query))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0] # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
return result
|
||||
|
||||
def query_ex(self, query, fetch: str = "all"):
|
||||
"""
|
||||
only for query
|
||||
Args:
|
||||
session:
|
||||
query:
|
||||
fetch:
|
||||
Returns:
|
||||
"""
|
||||
print(f"Query[{query}]")
|
||||
if not query:
|
||||
return []
|
||||
cursor = self.session.execute(text(query))
|
||||
if cursor.returns_rows:
|
||||
if fetch == "all":
|
||||
result = cursor.fetchall()
|
||||
elif fetch == "one":
|
||||
result = cursor.fetchone()[0] # type: ignore
|
||||
else:
|
||||
raise ValueError("Fetch parameter must be either 'one' or 'all'")
|
||||
field_names = list(i[0:] for i in cursor.keys())
|
||||
|
||||
result = list(result)
|
||||
return field_names, result
|
||||
return []
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
if not command or len(command) < 0:
|
||||
return []
|
||||
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
|
||||
if ttype == sqlparse.tokens.DML:
|
||||
if sql_type == "SELECT":
|
||||
return self.__query(command, fetch)
|
||||
else:
|
||||
self.__write(command)
|
||||
select_sql = self.convert_sql_write_to_select(command)
|
||||
print(f"write result query:{select_sql}")
|
||||
return self.__query(select_sql)
|
||||
|
||||
else:
|
||||
print(f"DDL execution determines whether to enable through configuration ")
|
||||
cursor = self.session.execute(text(command))
|
||||
self.session.commit()
|
||||
if cursor.returns_rows:
|
||||
result = cursor.fetchall()
|
||||
field_names = tuple(i[0:] for i in cursor.keys())
|
||||
result = list(result)
|
||||
result.insert(0, field_names)
|
||||
print("DDL Result:" + str(result))
|
||||
if not result:
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
return result
|
||||
else:
|
||||
return self.__query(f"SHOW COLUMNS FROM {table_name}")
|
||||
|
||||
def run_to_df(self, command: str, fetch: str = "all"):
|
||||
result_lst = self.run(command, fetch)
|
||||
colunms = result_lst[0]
|
||||
values = result_lst[1:]
|
||||
return pd.DataFrame(values, columns=colunms)
|
||||
|
||||
def run_no_throw(self, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
If the statement throws an error, the error message is returned.
|
||||
"""
|
||||
try:
|
||||
return self.run(command, fetch)
|
||||
except SQLAlchemyError as e:
|
||||
"""Format the error message"""
|
||||
return f"Error: {e}"
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
|
||||
def convert_sql_write_to_select(self, write_sql):
|
||||
"""
|
||||
SQL classification processing
|
||||
author:xiangh8
|
||||
Args:
|
||||
sql:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
# 将SQL命令转换为小写,并按空格拆分
|
||||
parts = write_sql.lower().split()
|
||||
# 获取命令类型(insert, delete, update)
|
||||
cmd_type = parts[0]
|
||||
|
||||
# 根据命令类型进行处理
|
||||
if cmd_type == "insert":
|
||||
match = re.match(
|
||||
r"insert into (\w+) \((.*?)\) values \((.*?)\)", write_sql.lower()
|
||||
)
|
||||
if match:
|
||||
table_name, columns, values = match.groups()
|
||||
# 将字段列表和值列表分割为单独的字段和值
|
||||
columns = columns.split(",")
|
||||
values = values.split(",")
|
||||
# 构造 WHERE 子句
|
||||
where_clause = " AND ".join(
|
||||
[
|
||||
f"{col.strip()}={val.strip()}"
|
||||
for col, val in zip(columns, values)
|
||||
]
|
||||
)
|
||||
return f"SELECT * FROM {table_name} WHERE {where_clause}"
|
||||
|
||||
elif cmd_type == "delete":
|
||||
table_name = parts[2] # delete from <table_name> ...
|
||||
# 返回一个select语句,它选择该表的所有数据
|
||||
return f"SELECT * FROM {table_name} "
|
||||
|
||||
elif cmd_type == "update":
|
||||
table_name = parts[1]
|
||||
set_idx = parts.index("set")
|
||||
where_idx = parts.index("where")
|
||||
# 截取 `set` 子句中的字段名
|
||||
set_clause = parts[set_idx + 1 : where_idx][0].split("=")[0].strip()
|
||||
# 截取 `where` 之后的条件语句
|
||||
where_clause = " ".join(parts[where_idx + 1 :])
|
||||
# 返回一个select语句,它选择更新的数据
|
||||
return f"SELECT {set_clause} FROM {table_name} WHERE {where_clause}"
|
||||
else:
|
||||
raise ValueError(f"Unsupported SQL command type: {cmd_type}")
|
||||
|
||||
def __sql_parse(self, sql):
|
||||
sql = sql.strip()
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
sql_type = parsed.get_type()
|
||||
table_name = parsed.get_name()
|
||||
|
||||
first_token = parsed.token_first(skip_ws=True, skip_cm=False)
|
||||
ttype = first_token.ttype
|
||||
print(f"SQL:{sql}, ttype:{ttype}, sql_type:{sql_type}, table:{table_name}")
|
||||
return parsed, ttype, sql_type, table_name
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW INDEXES FROM {table_name}"))
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[2], index[4]) for index in indexes]
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
|
||||
ans = cursor.fetchall()
|
||||
return ans[0][1]
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.COLUMNS where table_name='{table_name}'".format(
|
||||
table_name
|
||||
)
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SELECT @@character_set_database"))
|
||||
character_set = cursor.fetchone()[0]
|
||||
return character_set
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SELECT @@collation_database"))
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
|
||||
def get_grants(self):
|
||||
"""Get grant info."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
return grants
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
try:
|
||||
cursor = self.session.execute(text(f"SELECT user, host FROM mysql.user"))
|
||||
users = cursor.fetchall()
|
||||
return [(user[0], user[1]) for user in users]
|
||||
except Exception as e:
|
||||
return []
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""SELECT table_name, table_comment FROM information_schema.tables WHERE table_schema = '{db_name}'""".format(
|
||||
db_name
|
||||
)
|
||||
)
|
||||
)
|
||||
table_comments = cursor.fetchall()
|
||||
return [
|
||||
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||
]
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
|
||||
def get_database_names(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
110
dbgpt/datasource/rdbms/conn_clickhouse.py
Normal file
110
dbgpt/datasource/rdbms/conn_clickhouse.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import re
|
||||
from typing import Optional, Any
|
||||
from sqlalchemy import text
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class ClickhouseConnect(RDBMSDatabase):
|
||||
"""Connect Clickhouse Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
"""db type"""
|
||||
db_type: str = "clickhouse"
|
||||
"""db driver"""
|
||||
driver: str = "clickhouse"
|
||||
"""db dialect"""
|
||||
db_dialect: str = "clickhouse"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
return ""
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW CREATE TABLE {table_name}"))
|
||||
ans = cursor.fetchall()
|
||||
ans = ans[0][0]
|
||||
ans = re.sub(r"\s*ENGINE\s*=\s*MergeTree\s*", " ", ans, flags=re.IGNORECASE)
|
||||
ans = re.sub(
|
||||
r"\s*DEFAULT\s*CHARSET\s*=\s*\w+\s*", " ", ans, flags=re.IGNORECASE
|
||||
)
|
||||
ans = re.sub(r"\s*SETTINGS\s*\s*\w+\s*", " ", ans, flags=re.IGNORECASE)
|
||||
return ans
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT name, type, default_expression, is_in_primary_key, comment from system.columns where table='{table_name}'".format(
|
||||
table_name
|
||||
)
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_users(self):
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
|
||||
return []
|
||||
|
||||
def get_database_names(self):
|
||||
return []
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""SELECT table, comment FROM system.tables WHERE database = '{db_name}'""".format(
|
||||
db_name
|
||||
)
|
||||
)
|
||||
)
|
||||
table_comments = cursor.fetchall()
|
||||
return [
|
||||
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||
]
|
||||
|
||||
def table_simple_info(self):
|
||||
# group_concat() not supported in clickhouse, use arrayStringConcat+groupArray instead; and quotes need to be escaped
|
||||
_sql = f"""
|
||||
select concat(TABLE_NAME, \'(\' , arrayStringConcat(groupArray(column_name),\'-\'), \')\') as schema_info
|
||||
from information_schema.COLUMNS where table_schema=\'{self.get_current_db_name()}\' group by TABLE_NAME; """
|
||||
|
||||
cursor = self.session.execute(text(_sql))
|
||||
results = cursor.fetchall()
|
||||
return results
|
159
dbgpt/datasource/rdbms/conn_doris.py
Normal file
159
dbgpt/datasource/rdbms/conn_doris.py
Normal file
@@ -0,0 +1,159 @@
|
||||
from typing import Iterable, Optional, Any
|
||||
from sqlalchemy import text
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class DorisConnect(RDBMSDatabase):
|
||||
driver = "doris"
|
||||
db_type = "doris"
|
||||
db_dialect = "doris"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=database()"
|
||||
)
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
self._all_tables = table_results
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
cursor = self.get_session().execute(text("SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
if len(grants) == 0:
|
||||
return []
|
||||
if len(grants[0]) == 2:
|
||||
grants_list = [x[1] for x in grants]
|
||||
else:
|
||||
grants_list = [x[2] for x in grants]
|
||||
return grants_list
|
||||
|
||||
def _get_current_version(self):
|
||||
"""Get database current version"""
|
||||
return int(
|
||||
self.get_session().execute(text("select current_version()")).scalar()
|
||||
)
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation.
|
||||
ref: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-reference/Show-Statements/SHOW-COLLATION/
|
||||
"""
|
||||
cursor = self.get_session().execute(text("SHOW COLLATION"))
|
||||
results = cursor.fetchall()
|
||||
return "" if not results else results[0][0]
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
f"select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT "
|
||||
f"from information_schema.columns "
|
||||
f'where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
return "utf-8"
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
# cur = self.get_session().execute(
|
||||
# text(
|
||||
# f"""show create table {table_name}"""
|
||||
# )
|
||||
# )
|
||||
# rows = cur.fetchone()
|
||||
# create_sql = rows[1]
|
||||
# return create_sql
|
||||
# 这里是要表描述, 返回建表语句会导致token过长而失败
|
||||
cur = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT TABLE_COMMENT "
|
||||
f"FROM information_schema.tables "
|
||||
f'where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
)
|
||||
)
|
||||
table = cur.fetchone()
|
||||
if table:
|
||||
return str(table[0])
|
||||
else:
|
||||
return ""
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
db_name = "database()" if not db_name else f"'{db_name}'"
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT TABLE_NAME,TABLE_COMMENT "
|
||||
f"FROM information_schema.tables "
|
||||
f"where TABLE_SCHEMA={db_name}"
|
||||
)
|
||||
)
|
||||
tables = cursor.fetchall()
|
||||
return [(table[0], table[1]) for table in tables]
|
||||
|
||||
def get_database_list(self):
|
||||
return self.get_database_names()
|
||||
|
||||
def get_database_names(self):
|
||||
cursor = self.get_session().execute(text("SHOW DATABASES"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0]
|
||||
not in [
|
||||
"information_schema",
|
||||
"sys",
|
||||
"_statistics_",
|
||||
"mysql",
|
||||
"__internal_schema",
|
||||
"doris_audit_db__",
|
||||
]
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
return self.get_session().execute(text("select database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
cursor = self.get_session().execute(
|
||||
text(
|
||||
f"SELECT concat(TABLE_NAME,'(',group_concat(COLUMN_NAME,','),');') "
|
||||
f"FROM information_schema.columns "
|
||||
f"where TABLE_SCHEMA=database() "
|
||||
f"GROUP BY TABLE_NAME"
|
||||
)
|
||||
)
|
||||
results = cursor.fetchall()
|
||||
return [x[0] for x in results]
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
cursor = self.get_session().execute(text(f"SHOW INDEX FROM {table_name}"))
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[2], index[4]) for index in indexes]
|
79
dbgpt/datasource/rdbms/conn_duckdb.py
Normal file
79
dbgpt/datasource/rdbms/conn_duckdb.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Optional, Any, Iterable
|
||||
from sqlalchemy import (
|
||||
create_engine,
|
||||
text,
|
||||
)
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class DuckDbConnect(RDBMSDatabase):
|
||||
"""Connect Duckdb Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "duckdb"
|
||||
db_dialect: str = "duckdb"
|
||||
|
||||
@classmethod
|
||||
def from_file_path(
|
||||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
return cls(create_engine("duckdb:///" + file_path, **_engine_args), **kwargs)
|
||||
|
||||
def get_users(self):
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'duckdb_sys_users';"
|
||||
)
|
||||
)
|
||||
users = cursor.fetchall()
|
||||
return [(user[0], user[1]) for user in users]
|
||||
|
||||
def get_grants(self):
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
return "UTF-8"
|
||||
|
||||
def get_table_comments(self, db_name):
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT name, sql FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
)
|
||||
)
|
||||
table_comments = cursor.fetchall()
|
||||
return [
|
||||
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||
]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
PRAGMA table_info({table_name})
|
||||
"""
|
||||
cursor_colums = self.session.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
table_colums = []
|
||||
for row_col in colum_results:
|
||||
field_info = list(row_col)
|
||||
table_colums.append(field_info[1])
|
||||
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
47
dbgpt/datasource/rdbms/conn_mssql.py
Normal file
47
dbgpt/datasource/rdbms/conn_mssql.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import Optional, Any, Iterable
|
||||
|
||||
from sqlalchemy import (
|
||||
MetaData,
|
||||
Table,
|
||||
create_engine,
|
||||
inspect,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class MSSQLConnect(RDBMSDatabase):
|
||||
"""Connect MSSQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "mssql"
|
||||
db_dialect: str = "mssql"
|
||||
driver: str = "mssql+pymssql"
|
||||
|
||||
default_db = ["master", "model", "msdb", "tempdb", "modeldb", "resource", "sys"]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='BASE TABLE'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
SELECT COLUMN_NAME, DATA_TYPE FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME='{table_name}'
|
||||
"""
|
||||
cursor_colums = self.session.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
table_colums = []
|
||||
for row_col in colum_results:
|
||||
field_info = list(row_col)
|
||||
table_colums.append(field_info[0])
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
16
dbgpt/datasource/rdbms/conn_mysql.py
Normal file
16
dbgpt/datasource/rdbms/conn_mysql.py
Normal file
@@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class MySQLConnect(RDBMSDatabase):
|
||||
"""Connect MySQL Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "mysql"
|
||||
db_dialect: str = "mysql"
|
||||
driver: str = "mysql+pymysql"
|
||||
|
||||
default_db = ["information_schema", "performance_schema", "sys", "mysql"]
|
208
dbgpt/datasource/rdbms/conn_postgresql.py
Normal file
208
dbgpt/datasource/rdbms/conn_postgresql.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from typing import Iterable, Optional, Any
|
||||
from sqlalchemy import text
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class PostgreSQLDatabase(RDBMSDatabase):
|
||||
driver = "postgresql+psycopg2"
|
||||
db_type = "postgresql"
|
||||
db_dialect = "postgresql"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.session.execute(
|
||||
text(
|
||||
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
view_results = self.session.execute(
|
||||
text(
|
||||
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
view_results = set(row[0] for row in view_results)
|
||||
self._all_tables = table_results.union(view_results)
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT DISTINCT grantee, privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = CURRENT_USER;"""
|
||||
)
|
||||
)
|
||||
grants = cursor.fetchall()
|
||||
return grants
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
try:
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
|
||||
)
|
||||
)
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
except Exception as e:
|
||||
print("postgresql get collation error: ", e)
|
||||
return None
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
try:
|
||||
cursor = self.session.execute(
|
||||
text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")
|
||||
)
|
||||
users = cursor.fetchall()
|
||||
return [user[0] for user in users]
|
||||
except Exception as e:
|
||||
print("postgresql get users error: ", e)
|
||||
return []
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
|
||||
FROM information_schema.columns WHERE table_name = :table_name",
|
||||
),
|
||||
{"table_name": table_name},
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
|
||||
)
|
||||
)
|
||||
character_set = cursor.fetchone()[0]
|
||||
return character_set
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
|
||||
FROM pg_catalog.pg_attribute a
|
||||
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
|
||||
SELECT max(a.attnum)
|
||||
FROM pg_catalog.pg_attribute a
|
||||
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
|
||||
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
|
||||
"""
|
||||
)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
|
||||
create_table_query = f"CREATE TABLE {table_name} (\n"
|
||||
for row in rows:
|
||||
create_table_query += f" {row[0]} {row[1]},\n"
|
||||
create_table_query = create_table_query.rstrip(",\n") + "\n)"
|
||||
|
||||
return create_table_query
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
tablses = self.table_simple_info()
|
||||
comments = []
|
||||
for table in tablses:
|
||||
table_name = table[0]
|
||||
table_comment = self.get_show_create_table(table_name)
|
||||
comments.append((table_name, table_comment))
|
||||
return comments
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||
]
|
||||
|
||||
def get_database_names(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
return self.session.execute(text("SELECT current_database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
_sql = f"""
|
||||
SELECT table_name, string_agg(column_name, ', ') AS schema_info
|
||||
FROM (
|
||||
SELECT c.relname AS table_name, a.attname AS column_name
|
||||
FROM pg_catalog.pg_class c
|
||||
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
|
||||
JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
|
||||
WHERE c.relkind = 'r'
|
||||
AND a.attnum > 0
|
||||
AND NOT a.attisdropped
|
||||
AND n.nspname NOT LIKE 'pg_%'
|
||||
AND n.nspname != 'information_schema'
|
||||
ORDER BY c.relname, a.attnum
|
||||
) sub
|
||||
GROUP BY table_name;
|
||||
"""
|
||||
cursor = self.session.execute(text(_sql))
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
def get_fields(self, table_name, schema_name="public"):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
|
||||
FROM information_schema.columns c
|
||||
LEFT JOIN pg_catalog.pg_description d
|
||||
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
|
||||
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
|
||||
"""
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
|
||||
)
|
||||
)
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[0], index[1]) for index in indexes]
|
129
dbgpt/datasource/rdbms/conn_sqlite.py
Normal file
129
dbgpt/datasource/rdbms/conn_sqlite.py
Normal file
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from typing import Optional, Any, Iterable
|
||||
from sqlalchemy import create_engine, text
|
||||
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class SQLiteConnect(RDBMSDatabase):
|
||||
"""Connect SQLite Database fetch MetaData
|
||||
Args:
|
||||
Usage:
|
||||
"""
|
||||
|
||||
db_type: str = "sqlite"
|
||||
db_dialect: str = "sqlite"
|
||||
|
||||
@classmethod
|
||||
def from_file_path(
|
||||
cls, file_path: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> RDBMSDatabase:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
_engine_args["connect_args"] = {"check_same_thread": False}
|
||||
# _engine_args["echo"] = True
|
||||
directory = os.path.dirname(file_path)
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory)
|
||||
return cls(create_engine("sqlite:///" + file_path, **_engine_args), **kwargs)
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
cursor = self.session.execute(text(f"PRAGMA index_list({table_name})"))
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[1], index[3]) for index in indexes]
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
"""Get table show create table about specified table."""
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}'"
|
||||
)
|
||||
)
|
||||
ans = cursor.fetchall()
|
||||
return ans[0][0]
|
||||
|
||||
def get_fields(self, table_name):
|
||||
"""Get column fields about specified table."""
|
||||
cursor = self.session.execute(text(f"PRAGMA table_info('{table_name}')"))
|
||||
fields = cursor.fetchall()
|
||||
print(fields)
|
||||
return [(field[1], field[2], field[3], field[4], field[5]) for field in fields]
|
||||
|
||||
def get_users(self):
|
||||
return []
|
||||
|
||||
def get_grants(self):
|
||||
return []
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
return "UTF-8"
|
||||
|
||||
def get_charset(self):
|
||||
return "UTF-8"
|
||||
|
||||
def get_database_list(self):
|
||||
return []
|
||||
|
||||
def get_database_names(self):
|
||||
return []
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.session.execute(
|
||||
text("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
)
|
||||
view_results = self.session.execute(
|
||||
text("SELECT name FROM sqlite_master WHERE type='view'")
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
view_results = set(row[0] for row in view_results)
|
||||
self._all_tables = table_results.union(view_results)
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def _write(self, session, write_sql):
|
||||
print(f"Write[{write_sql}]")
|
||||
result = session.execute(text(write_sql))
|
||||
session.commit()
|
||||
# TODO Subsequent optimization of dynamically specified database submission loss target problem
|
||||
print(f"SQL[{write_sql}], result:{result.rowcount}")
|
||||
return result.rowcount
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
cursor = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT name, sql FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
)
|
||||
)
|
||||
table_comments = cursor.fetchall()
|
||||
return [
|
||||
(table_comment[0], table_comment[1]) for table_comment in table_comments
|
||||
]
|
||||
|
||||
def table_simple_info(self) -> Iterable[str]:
|
||||
_tables_sql = f"""
|
||||
SELECT name FROM sqlite_master WHERE type='table'
|
||||
"""
|
||||
cursor = self.session.execute(text(_tables_sql))
|
||||
tables_results = cursor.fetchall()
|
||||
results = []
|
||||
for row in tables_results:
|
||||
table_name = row[0]
|
||||
_sql = f"""
|
||||
PRAGMA table_info({table_name})
|
||||
"""
|
||||
cursor_colums = self.session.execute(text(_sql))
|
||||
colum_results = cursor_colums.fetchall()
|
||||
table_colums = []
|
||||
for row_col in colum_results:
|
||||
field_info = list(row_col)
|
||||
table_colums.append(field_info[1])
|
||||
|
||||
results.append(f"{table_name}({','.join(table_colums)});")
|
||||
return results
|
150
dbgpt/datasource/rdbms/conn_starrocks.py
Normal file
150
dbgpt/datasource/rdbms/conn_starrocks.py
Normal file
@@ -0,0 +1,150 @@
|
||||
from typing import Iterable, Optional, Any
|
||||
from sqlalchemy import text
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import quote_plus as urlquote
|
||||
from dbgpt.datasource.rdbms.base import RDBMSDatabase
|
||||
from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import *
|
||||
|
||||
|
||||
class StarRocksConnect(RDBMSDatabase):
|
||||
driver = "starrocks"
|
||||
db_type = "starrocks"
|
||||
db_dialect = "starrocks"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
user: str,
|
||||
pwd: str,
|
||||
db_name: str,
|
||||
engine_args: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> RDBMSDatabase:
|
||||
db_url: str = (
|
||||
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
db_name = self.get_current_db_name()
|
||||
table_results = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
|
||||
)
|
||||
)
|
||||
# view_results = self.session.execute(text(f'SELECT TABLE_NAME from information_schema.materialized_views where TABLE_SCHEMA="{db_name}"'))
|
||||
table_results = set(row[0] for row in table_results)
|
||||
# view_results = set(row[0] for row in view_results)
|
||||
self._all_tables = table_results
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
def get_grants(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SHOW GRANTS"))
|
||||
grants = cursor.fetchall()
|
||||
if len(grants) == 0:
|
||||
return []
|
||||
if len(grants[0]) == 2:
|
||||
grants_list = [x[1] for x in grants]
|
||||
else:
|
||||
grants_list = [x[2] for x in grants]
|
||||
return grants_list
|
||||
|
||||
def _get_current_version(self):
|
||||
"""Get database current version"""
|
||||
return int(self.session.execute(text("select current_version()")).scalar())
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
# StarRocks 排序是表级别的
|
||||
return None
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
return []
|
||||
|
||||
def get_fields(self, table_name, db_name="database()"):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
if db_name != "database()":
|
||||
db_name = f'"{db_name}"'
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f'select COLUMN_NAME, COLUMN_TYPE, COLUMN_DEFAULT, IS_NULLABLE, COLUMN_COMMENT from information_schema.columns where TABLE_NAME="{table_name}" and TABLE_SCHEMA = {db_name}'
|
||||
)
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
|
||||
return "utf-8"
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
# cur = self.session.execute(
|
||||
# text(
|
||||
# f"""show create table {table_name}"""
|
||||
# )
|
||||
# )
|
||||
# rows = cur.fetchone()
|
||||
# create_sql = rows[0]
|
||||
|
||||
# return create_sql
|
||||
# 这里是要表描述, 返回建表语句会导致token过长而失败
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_COMMENT FROM information_schema.tables where TABLE_NAME="{table_name}" and TABLE_SCHEMA=database()'
|
||||
)
|
||||
)
|
||||
table = cur.fetchone()
|
||||
if table:
|
||||
return str(table[0])
|
||||
else:
|
||||
return ""
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
if not db_name:
|
||||
db_name = self.get_current_db_name()
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f'SELECT TABLE_NAME,TABLE_COMMENT FROM information_schema.tables where TABLE_SCHEMA="{db_name}"'
|
||||
)
|
||||
)
|
||||
tables = cur.fetchall()
|
||||
return [(table[0], table[1]) for table in tables]
|
||||
|
||||
def get_database_list(self):
|
||||
return self.get_database_names()
|
||||
|
||||
def get_database_names(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SHOW DATABASES;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "sys", "_statistics_", "dataease"]
|
||||
]
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
return self.session.execute(text("select database()")).scalar()
|
||||
|
||||
def table_simple_info(self):
|
||||
_sql = f"""
|
||||
SELECT concat(TABLE_NAME,"(",group_concat(COLUMN_NAME,","),");") FROM information_schema.columns where TABLE_SCHEMA=database()
|
||||
GROUP BY TABLE_NAME
|
||||
"""
|
||||
cursor = self.session.execute(text(_sql))
|
||||
results = cursor.fetchall()
|
||||
return [x[0] for x in results]
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SHOW INDEX FROM {table_name}"))
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[2], index[4]) for index in indexes]
|
0
dbgpt/datasource/rdbms/dialect/__init__.py
Normal file
0
dbgpt/datasource/rdbms/dialect/__init__.py
Normal file
14
dbgpt/datasource/rdbms/dialect/starrocks/__init__.py
Normal file
14
dbgpt/datasource/rdbms/dialect/starrocks/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
#! /usr/bin/python3
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https:#www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@@ -0,0 +1,22 @@
|
||||
#! /usr/bin/python3
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https:#www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlalchemy.dialects import registry
|
||||
|
||||
registry.register(
|
||||
"starrocks",
|
||||
"dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy.dialect",
|
||||
"StarRocksDialect",
|
||||
)
|
104
dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py
Normal file
104
dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/datatype.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, List, Any, Type, Dict
|
||||
|
||||
from sqlalchemy import Numeric, Integer, Float
|
||||
from sqlalchemy.sql import sqltypes
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TINYINT(Integer): # pylint: disable=no-init
|
||||
__visit_name__ = "TINYINT"
|
||||
|
||||
|
||||
class LARGEINT(Integer): # pylint: disable=no-init
|
||||
__visit_name__ = "LARGEINT"
|
||||
|
||||
|
||||
class DOUBLE(Float): # pylint: disable=no-init
|
||||
__visit_name__ = "DOUBLE"
|
||||
|
||||
|
||||
class HLL(Numeric): # pylint: disable=no-init
|
||||
__visit_name__ = "HLL"
|
||||
|
||||
|
||||
class BITMAP(Numeric): # pylint: disable=no-init
|
||||
__visit_name__ = "BITMAP"
|
||||
|
||||
|
||||
class PERCENTILE(Numeric): # pylint: disable=no-init
|
||||
__visit_name__ = "PERCENTILE"
|
||||
|
||||
|
||||
class ARRAY(TypeEngine): # pylint: disable=no-init
|
||||
__visit_name__ = "ARRAY"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[List[Any]]]:
|
||||
return list
|
||||
|
||||
|
||||
class MAP(TypeEngine): # pylint: disable=no-init
|
||||
__visit_name__ = "MAP"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
|
||||
return dict
|
||||
|
||||
|
||||
class STRUCT(TypeEngine): # pylint: disable=no-init
|
||||
__visit_name__ = "STRUCT"
|
||||
|
||||
@property
|
||||
def python_type(self) -> Optional[Type[Any]]:
|
||||
return None
|
||||
|
||||
|
||||
_type_map = {
|
||||
# === Boolean ===
|
||||
"boolean": sqltypes.BOOLEAN,
|
||||
# === Integer ===
|
||||
"tinyint": sqltypes.SMALLINT,
|
||||
"smallint": sqltypes.SMALLINT,
|
||||
"int": sqltypes.INTEGER,
|
||||
"bigint": sqltypes.BIGINT,
|
||||
"largeint": LARGEINT,
|
||||
# === Floating-point ===
|
||||
"float": sqltypes.FLOAT,
|
||||
"double": DOUBLE,
|
||||
# === Fixed-precision ===
|
||||
"decimal": sqltypes.DECIMAL,
|
||||
# === String ===
|
||||
"varchar": sqltypes.VARCHAR,
|
||||
"char": sqltypes.CHAR,
|
||||
"json": sqltypes.JSON,
|
||||
# === Date and time ===
|
||||
"date": sqltypes.DATE,
|
||||
"datetime": sqltypes.DATETIME,
|
||||
"timestamp": sqltypes.DATETIME,
|
||||
# === Structural ===
|
||||
"array": ARRAY,
|
||||
"map": MAP,
|
||||
"struct": STRUCT,
|
||||
"hll": HLL,
|
||||
"percentile": PERCENTILE,
|
||||
"bitmap": BITMAP,
|
||||
}
|
||||
|
||||
|
||||
def parse_sqltype(type_str: str) -> TypeEngine:
|
||||
type_str = type_str.strip().lower()
|
||||
match = re.match(r"^(?P<type>\w+)\s*(?:\((?P<options>.*)\))?", type_str)
|
||||
if not match:
|
||||
logger.warning(f"Could not parse type name '{type_str}'")
|
||||
return sqltypes.NULLTYPE
|
||||
type_name = match.group("type")
|
||||
|
||||
if type_name not in _type_map:
|
||||
logger.warning(f"Did not recognize type '{type_name}'")
|
||||
return sqltypes.NULLTYPE
|
||||
type_class = _type_map[type_name]
|
||||
return type_class()
|
173
dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py
Normal file
173
dbgpt/datasource/rdbms/dialect/starrocks/sqlalchemy/dialect.py
Normal file
@@ -0,0 +1,173 @@
|
||||
#! /usr/bin/python3
|
||||
# Copyright 2021-present StarRocks, Inc. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https:#www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from sqlalchemy import log, exc, text
|
||||
from sqlalchemy.dialects.mysql.pymysql import MySQLDialect_pymysql
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from dbgpt.datasource.rdbms.dialect.starrocks.sqlalchemy import datatype
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@log.class_logger
|
||||
class StarRocksDialect(MySQLDialect_pymysql):
|
||||
# Caching
|
||||
# Warnings are generated by SQLAlchmey if this flag is not explicitly set
|
||||
# and tests are needed before being enabled
|
||||
supports_statement_cache = False
|
||||
|
||||
name = "starrocks"
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
super(StarRocksDialect, self).__init__(*args, **kw)
|
||||
|
||||
def has_table(self, connection, table_name, schema=None, **kw):
|
||||
self._ensure_has_table_connection(connection)
|
||||
|
||||
if schema is None:
|
||||
schema = self.default_schema_name
|
||||
|
||||
assert schema is not None
|
||||
|
||||
quote = self.identifier_preparer.quote_identifier
|
||||
full_name = quote(table_name)
|
||||
if schema:
|
||||
full_name = "{}.{}".format(quote(schema), full_name)
|
||||
|
||||
res = connection.execute(text(f"DESCRIBE {full_name}"))
|
||||
return res.first() is not None
|
||||
|
||||
def get_schema_names(self, connection, **kw):
|
||||
rp = connection.exec_driver_sql("SHOW schemas")
|
||||
return [r[0] for r in rp]
|
||||
|
||||
def get_table_names(self, connection, schema=None, **kw):
|
||||
"""Return a Unicode SHOW TABLES from a given schema."""
|
||||
if schema is not None:
|
||||
current_schema = schema
|
||||
else:
|
||||
current_schema = self.default_schema_name
|
||||
|
||||
charset = self._connection_charset
|
||||
|
||||
rp = connection.exec_driver_sql(
|
||||
"SHOW FULL TABLES FROM %s"
|
||||
% self.identifier_preparer.quote_identifier(current_schema)
|
||||
)
|
||||
|
||||
return [
|
||||
row[0]
|
||||
for row in self._compat_fetchall(rp, charset=charset)
|
||||
if row[1] == "BASE TABLE"
|
||||
]
|
||||
|
||||
def get_view_names(self, connection, schema=None, **kw):
|
||||
if schema is None:
|
||||
schema = self.default_schema_name
|
||||
charset = self._connection_charset
|
||||
rp = connection.exec_driver_sql(
|
||||
"SHOW FULL TABLES FROM %s"
|
||||
% self.identifier_preparer.quote_identifier(schema)
|
||||
)
|
||||
return [
|
||||
row[0]
|
||||
for row in self._compat_fetchall(rp, charset=charset)
|
||||
if row[1] in ("VIEW", "SYSTEM VIEW")
|
||||
]
|
||||
|
||||
def get_columns(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not self.has_table(connection, table_name, schema):
|
||||
raise exc.NoSuchTableError(f"schema={schema}, table={table_name}")
|
||||
schema = schema or self._get_default_schema_name(connection)
|
||||
|
||||
quote = self.identifier_preparer.quote_identifier
|
||||
full_name = quote(table_name)
|
||||
if schema:
|
||||
full_name = "{}.{}".format(quote(schema), full_name)
|
||||
|
||||
res = connection.execute(text(f"SHOW COLUMNS FROM {full_name}"))
|
||||
columns = []
|
||||
for record in res:
|
||||
column = dict(
|
||||
name=record.Field,
|
||||
type=datatype.parse_sqltype(record.Type),
|
||||
nullable=record.Null == "YES",
|
||||
default=record.Default,
|
||||
)
|
||||
columns.append(column)
|
||||
return columns
|
||||
|
||||
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
|
||||
return { # type: ignore # pep-655 not supported
|
||||
"name": None,
|
||||
"constrained_columns": [],
|
||||
}
|
||||
|
||||
def get_unique_constraints(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def get_check_constraints(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def get_foreign_keys(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def get_primary_keys(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> List[str]:
|
||||
pk = self.get_pk_constraint(connection, table_name, schema)
|
||||
return pk.get("constrained_columns") # type: ignore
|
||||
|
||||
def get_indexes(self, connection, table_name, schema=None, **kw):
|
||||
return []
|
||||
|
||||
def has_sequence(
|
||||
self, connection: Connection, sequence_name: str, schema: str = None, **kw
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
def get_sequence_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_temp_view_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_temp_table_names(
|
||||
self, connection: Connection, schema: str = None, **kw
|
||||
) -> List[str]:
|
||||
return []
|
||||
|
||||
def get_table_options(self, connection, table_name, schema=None, **kw):
|
||||
return {}
|
||||
|
||||
def get_table_comment(
|
||||
self, connection: Connection, table_name: str, schema: str = None, **kw
|
||||
) -> Dict[str, Any]:
|
||||
return dict(text=None)
|
0
dbgpt/datasource/rdbms/tests/__init__.py
Normal file
0
dbgpt/datasource/rdbms/tests/__init__.py
Normal file
137
dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
Normal file
137
dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Run unit test with command: pytest dbgpt/datasource/rdbms/tests/test_conn_sqlite.py
|
||||
"""
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from dbgpt.datasource.rdbms.conn_sqlite import SQLiteConnect
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db():
|
||||
temp_db_file = tempfile.NamedTemporaryFile(delete=False)
|
||||
temp_db_file.close()
|
||||
conn = SQLiteConnect.from_file_path(temp_db_file.name)
|
||||
yield conn
|
||||
os.unlink(temp_db_file.name)
|
||||
|
||||
|
||||
def test_get_table_names(db):
|
||||
assert list(db.get_table_names()) == []
|
||||
|
||||
|
||||
def test_get_table_info(db):
|
||||
assert db.get_table_info() == ""
|
||||
|
||||
|
||||
def test_get_table_info_with_table(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER);")
|
||||
print(db._sync_tables_from_db())
|
||||
table_info = db.get_table_info()
|
||||
assert "CREATE TABLE test" in table_info
|
||||
|
||||
|
||||
def test_run_sql(db):
|
||||
result = db.run(db.session, "CREATE TABLE test (id INTEGER);")
|
||||
assert result[0] == ("cid", "name", "type", "notnull", "dflt_value", "pk")
|
||||
|
||||
|
||||
def test_run_no_throw(db):
|
||||
assert db.run_no_throw(db.session, "this is a error sql").startswith("Error:")
|
||||
|
||||
|
||||
def test_get_indexes(db):
|
||||
db.run(db.session, "CREATE TABLE test (name TEXT);")
|
||||
db.run(db.session, "CREATE INDEX idx_name ON test(name);")
|
||||
assert db.get_indexes("test") == [("idx_name", "c")]
|
||||
|
||||
|
||||
def test_get_indexes_empty(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_indexes("test") == []
|
||||
|
||||
|
||||
def test_get_show_create_table(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert (
|
||||
db.get_show_create_table("test") == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
|
||||
)
|
||||
|
||||
|
||||
def test_get_fields(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_fields("test") == [("id", "INTEGER", 0, None, 1)]
|
||||
|
||||
|
||||
def test_get_charset(db):
|
||||
assert db.get_charset() == "UTF-8"
|
||||
|
||||
|
||||
def test_get_collation(db):
|
||||
assert db.get_collation() == "UTF-8"
|
||||
|
||||
|
||||
def test_table_simple_info(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.table_simple_info() == ["test(id);"]
|
||||
|
||||
|
||||
def test_get_table_info_no_throw(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_table_info_no_throw("xxxx_table").startswith("Error:")
|
||||
|
||||
|
||||
def test_query_ex(db):
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
db.run(db.session, "insert into test(id) values (1)")
|
||||
db.run(db.session, "insert into test(id) values (2)")
|
||||
field_names, result = db.query_ex(db.session, "select * from test")
|
||||
assert field_names == ["id"]
|
||||
assert result == [(1,), (2,)]
|
||||
|
||||
field_names, result = db.query_ex(db.session, "select * from test", fetch="one")
|
||||
assert field_names == ["id"]
|
||||
assert result == [(1,)]
|
||||
|
||||
|
||||
def test_convert_sql_write_to_select(db):
|
||||
# TODO
|
||||
pass
|
||||
|
||||
|
||||
def test_get_grants(db):
|
||||
assert db.get_grants() == []
|
||||
|
||||
|
||||
def test_get_users(db):
|
||||
assert db.get_users() == []
|
||||
|
||||
|
||||
def test_get_table_comments(db):
|
||||
assert db.get_table_comments() == []
|
||||
db.run(db.session, "CREATE TABLE test (id INTEGER PRIMARY KEY);")
|
||||
assert db.get_table_comments() == [
|
||||
("test", "CREATE TABLE test (id INTEGER PRIMARY KEY)")
|
||||
]
|
||||
|
||||
|
||||
def test_get_database_list(db):
|
||||
db.get_database_list() == []
|
||||
|
||||
|
||||
def test_get_database_names(db):
|
||||
db.get_database_names() == []
|
||||
|
||||
|
||||
def test_db_dir_exist_dir():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
new_dir = os.path.join(temp_dir, "new_dir")
|
||||
file_path = os.path.join(new_dir, "sqlite.db")
|
||||
db = SQLiteConnect.from_file_path(file_path)
|
||||
assert os.path.exists(new_dir) == True
|
||||
assert list(db.get_table_names()) == []
|
||||
with tempfile.TemporaryDirectory() as existing_dir:
|
||||
file_path = os.path.join(existing_dir, "sqlite.db")
|
||||
db = SQLiteConnect.from_file_path(file_path)
|
||||
assert os.path.exists(existing_dir) == True
|
||||
assert list(db.get_table_names()) == []
|
8
dbgpt/datasource/redis.py
Normal file
8
dbgpt/datasource/redis.py
Normal file
@@ -0,0 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
|
||||
|
||||
class RedisConnector:
|
||||
"""RedisConnector"""
|
||||
|
||||
pass
|
Reference in New Issue
Block a user