mirror of
https://github.com/csunny/DB-GPT.git
synced 2026-01-29 21:49:35 +00:00
add plugin mode
This commit is contained in:
@@ -30,16 +30,16 @@ class Database:
|
||||
"""SQLAlchemy wrapper around a database."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
self,
|
||||
engine,
|
||||
schema: Optional[str] = None,
|
||||
metadata: Optional[MetaData] = None,
|
||||
ignore_tables: Optional[List[str]] = None,
|
||||
include_tables: Optional[List[str]] = None,
|
||||
sample_rows_in_table_info: int = 3,
|
||||
indexes_in_table_info: bool = False,
|
||||
custom_table_info: Optional[dict] = None,
|
||||
view_support: bool = False,
|
||||
):
|
||||
"""Create engine from database URI."""
|
||||
self._engine = engine
|
||||
@@ -119,7 +119,7 @@ class Database:
|
||||
|
||||
@classmethod
|
||||
def from_uri(
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
|
||||
) -> Database:
|
||||
"""Construct a SQLAlchemy engine from URI."""
|
||||
_engine_args = engine_args or {}
|
||||
@@ -148,7 +148,7 @@ class Database:
|
||||
|
||||
self._metadata = MetaData()
|
||||
# sql = f"use {db_name}"
|
||||
sql = text(f'use `{db_name}`')
|
||||
sql = text(f"use `{db_name}`")
|
||||
session.execute(sql)
|
||||
|
||||
# 处理表信息数据
|
||||
@@ -159,13 +159,17 @@ class Database:
|
||||
# tables list if view_support is True
|
||||
self._all_tables = set(
|
||||
self._inspector.get_table_names(schema=db_name)
|
||||
+ (self._inspector.get_view_names(schema=db_name) if self.view_support else [])
|
||||
+ (
|
||||
self._inspector.get_view_names(schema=db_name)
|
||||
if self.view_support
|
||||
else []
|
||||
)
|
||||
)
|
||||
|
||||
return session
|
||||
|
||||
def get_current_db_name(self, session) -> str:
|
||||
return session.execute(text('SELECT DATABASE()')).scalar()
|
||||
return session.execute(text("SELECT DATABASE()")).scalar()
|
||||
|
||||
def table_simple_info(self, session):
|
||||
_sql = f"""
|
||||
@@ -201,7 +205,7 @@ class Database:
|
||||
tbl
|
||||
for tbl in self._metadata.sorted_tables
|
||||
if tbl.name in set(all_table_names)
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
||||
]
|
||||
|
||||
tables = []
|
||||
@@ -214,7 +218,7 @@ class Database:
|
||||
create_table = str(CreateTable(table).compile(self._engine))
|
||||
table_info = f"{create_table.rstrip()}"
|
||||
has_extra_info = (
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
self._indexes_in_table_info or self._sample_rows_in_table_info
|
||||
)
|
||||
if has_extra_info:
|
||||
table_info += "\n\n/*"
|
||||
@@ -303,6 +307,10 @@ class Database:
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(' show databases;'))
|
||||
cursor = session.execute(text(" show databases;"))
|
||||
results = cursor.fetchall()
|
||||
return [d[0] for d in results if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]]
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user