Files
DB-GPT/pilot/common/sql_database.py
2023-05-29 19:32:20 +08:00

317 lines
12 KiB
Python

from __future__ import annotations
import warnings
from typing import Any, Iterable, List, Optional
from pydantic import BaseModel, Field, root_validator, validator, Extra
from abc import ABC, abstractmethod
import sqlalchemy
from sqlalchemy import (
MetaData,
Table,
create_engine,
inspect,
select,
text,
)
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
return (
f'Name: {index["name"]}, Unique: {index["unique"]},'
f' Columns: {str(index["column_names"])}'
)
class Database:
    """SQLAlchemy wrapper around a database.

    The SQL issued by this class (``USE``, ``SELECT DATABASE()``,
    ``SHOW DATABASES``, ``information_schema`` queries) targets a
    MySQL-compatible server. Table metadata is not reflected at construction
    time; it is loaded for a specific database by :meth:`get_session`.
    """

    def __init__(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: SQLAlchemy engine used for sessions, reflection and queries.
            schema: Stored as ``_schema``; not read by the methods of this class.
            metadata: Initial ``MetaData``. It is replaced by a freshly
                reflected one on every :meth:`get_session` call.
            ignore_tables: Table names excluded from ``get_usable_table_names``.
            include_tables: If given, the only table names reported as usable.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description by ``get_table_info`` (0 disables them).
            indexes_in_table_info: Whether ``get_table_info`` appends index info.
            custom_table_info: Mapping of table name -> text used instead of the
                generated ``CREATE TABLE`` description.
            view_support: Whether :meth:`get_session` also lists/reflects views.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are given.
            TypeError: If ``sample_rows_in_table_info`` is not an int or
                ``custom_table_info`` is not a dict.
        """
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")
        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        if custom_table_info is not None and not isinstance(custom_table_info, dict):
            raise TypeError(
                "table_info must be a dictionary with table names as keys and the "
                "desired table info as values"
            )

        self._engine = engine
        self._schema = schema
        self._inspector = inspect(self._engine)
        # Registry of sessions; calling it yields the session used by the
        # methods below.
        self._db_sessions = scoped_session(sessionmaker(bind=engine))

        # Populated per database by get_session().
        self._all_tables = set()
        self._usable_tables = set()
        self._metadata = metadata if metadata is not None else MetaData()

        self.view_support = view_support
        self._include_tables = set(include_tables) if include_tables else set()
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info
        self._custom_table_info = dict(custom_table_info) if custom_table_info else {}

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> Database:
        """Construct a SQLAlchemy engine from URI."""
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns ``include_tables`` when configured, otherwise all known tables
        minus ``ignore_tables``.
        """
        if self._include_tables:
            return self._include_tables
        return self._all_tables - self._ignore_tables

    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available (deprecated alias)."""
        warnings.warn(
            "This method is deprecated - please use `get_usable_table_names`.",
            stacklevel=2,
        )
        return self.get_usable_table_names()

    def get_session(self, db_name: str):
        """Return a session switched to ``db_name`` and reflect that database.

        Side effects: replaces ``self._metadata`` with a freshly reflected
        ``MetaData`` for ``db_name`` and recomputes ``self._all_tables``
        (table names, plus view names when ``view_support`` is enabled).
        """
        session = self._db_sessions()
        # A backtick inside a backtick-quoted MySQL identifier is escaped by
        # doubling it, so db_name cannot break out of the quoting.
        quoted_db_name = db_name.replace("`", "``")
        session.execute(text(f"use `{quoted_db_name}`"))

        # Reflect table metadata for the selected database.
        self._metadata = MetaData()
        self._metadata.reflect(
            bind=self._engine, schema=db_name, views=self.view_support
        )

        # Use a fresh inspector so results cached by an earlier inspector
        # (e.g. before tables were created) are not reused.
        self._inspector = inspect(self._engine)
        table_names = list(self._inspector.get_table_names(schema=db_name))
        if self.view_support:
            table_names += self._inspector.get_view_names(schema=db_name)
        self._all_tables = set(table_names)
        return session

    def get_current_db_name(self, session) -> str:
        """Return the database currently selected on ``session``."""
        return session.execute(text("SELECT DATABASE()")).scalar()

    def table_simple_info(self, session):
        """Return one row per table of the current database.

        Each row's ``schema_info`` column reads ``table_name(col1,col2,...)``.
        """
        sql = text(
            "select concat(table_name, '(', group_concat(column_name), ')')"
            " as schema_info from information_schema.COLUMNS"
            " where table_schema = :table_schema group by TABLE_NAME"
        )
        # Bind the schema name instead of splicing it into the SQL text.
        cursor = session.execute(
            sql, {"table_schema": self.get_current_db_name(session)}
        )
        return cursor.fetchall()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Raises:
            ValueError: If any requested table is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set once rather than once per reflected table.
        wanted_names = set(all_table_names)
        skip_sqlite_internal = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted_names
            and not (skip_sqlite_internal and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_rows(self, table: Table) -> str:
        """Return a header line, the column names and up to N sample rows."""
        command = select(table).limit(self._sample_rows_in_table_info)
        columns_str = "\t".join(col.name for col in table.columns)
        try:
            with self._engine.connect() as connection:
                sample_rows_result: CursorResult = connection.execute(command)
                # Truncate every value to 100 characters.
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _get_table_indexes(self, table: Table) -> str:
        """Describe the indexes of ``table``, one per line."""
        # Pass the table's schema so the lookup targets the database the
        # table was reflected from, not the connection's default schema.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables, or ``"Error: ..."`` on ValueError."""
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message instead of propagating it.
            return f"Error: {e}"

    def run(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return its rows prefixed by a header row.

        Args:
            session: Session to execute on.
            command: Raw SQL text.
            fetch: ``"all"`` for every row, ``"one"`` for at most the first row.

        Returns:
            ``[column_names_tuple, row, ...]`` when the statement returns rows
            (just ``[column_names_tuple]`` if there are none); an empty list
            when the statement produces no result set.

        Raises:
            ValueError: If ``fetch`` is neither ``"one"`` nor ``"all"``.
        """
        cursor = session.execute(text(command))
        if not cursor.returns_rows:
            return []
        if fetch == "all":
            rows = list(cursor.fetchall())
        elif fetch == "one":
            first_row = cursor.fetchone()
            rows = [] if first_row is None else [first_row]
        else:
            raise ValueError("Fetch parameter must be either 'one' or 'all'")
        field_names = tuple(cursor.keys())
        return [field_names, *rows]

    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command like :meth:`run`, but never raise SQLAlchemyError.

        If the statement returns rows, the header row plus rows are returned.
        If the statement returns no result set, an empty list is returned.
        If the statement raises a SQLAlchemyError, ``"Error: <message>"`` is
        returned instead.
        """
        try:
            return self.run(session, command, fetch)
        except SQLAlchemyError as e:
            # Format the error message instead of propagating it.
            return f"Error: {e}"

    def get_database_list(self):
        """Return database names on the server, excluding MySQL system schemas."""
        session = self._db_sessions()
        cursor = session.execute(text(" show databases;"))
        results = cursor.fetchall()
        system_databases = ("information_schema", "performance_schema", "sys", "mysql")
        return [d[0] for d in results if d[0] not in system_databases]