from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import Any, Iterable, List, Optional
from urllib.parse import quote_plus

from pydantic import BaseModel, Field, root_validator, validator, Extra
import sqlalchemy
from sqlalchemy import (
    MetaData,
    Table,
    create_engine,
    inspect,
    select,
    text,
)
from sqlalchemy.engine import CursorResult, Engine
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError
from sqlalchemy.schema import CreateTable
from sqlalchemy.orm import sessionmaker, scoped_session

from pilot.connections.base import BaseConnect
from pilot.configs.config import Config

CFG = Config()


def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str:
    """Render one reflected index as a single human-readable line."""
    return (
        f'Name: {index["name"]}, Unique: {index["unique"]},'
        f' Columns: {str(index["column_names"])}'
    )


class RDBMSDatabase(BaseConnect):
    """SQLAlchemy wrapper around a database.

    ``from_uri_db`` reads ``cls.connect_driver`` and ``cls.dialect`` at class
    level, so concrete subclasses must define both as string class attributes
    (the ``dialect`` property below only yields a string on instances).
    """

    def __init__(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
    ):
        """Wrap an existing SQLAlchemy engine.

        Args:
            engine: SQLAlchemy engine used for sessions, reflection and queries.
            schema: Optional schema name, stored as ``self._schema``.
            metadata: Optional initial ``MetaData``; it is replaced by a fresh
                one on every ``get_session`` call.
            ignore_tables: Table names excluded from ``get_usable_table_names``.
            include_tables: If given, the only table names reported as usable.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are given.
        """
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")
        self._engine = engine
        self._schema = schema
        # Keep the table filters; get_usable_table_names() relies on them.
        self._include_tables = set(include_tables) if include_tables else set()
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        self._metadata = metadata if metadata is not None else MetaData()
        # Filled in by get_session() once a database has been selected.
        self._all_tables: set = set()
        self._inspector = inspect(self._engine)
        session_factory = sessionmaker(bind=engine)
        self._db_sessions = scoped_session(session_factory)

    @classmethod
    def from_config(cls) -> RDBMSDatabase:
        """Build an instance from the ``CFG.LOCAL_DB_*`` settings.

        TODO: password encryption.
        """
        # from_uri_db is a classmethod: ``cls`` is bound implicitly and must
        # not be passed again, or every argument shifts by one position.
        return cls.from_uri_db(
            CFG.LOCAL_DB_HOST,
            CFG.LOCAL_DB_PORT,
            CFG.LOCAL_DB_USER,
            CFG.LOCAL_DB_PASSWORD,
            engine_args={"pool_size": 10, "pool_recycle": 3600, "echo": True},
        )

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> RDBMSDatabase:
        """Build a connection URL from its parts and construct an instance.

        The URL has the form
        ``[dialect+]connect_driver://user:pwd@host:port[/db_name]``.
        User name and password are percent-encoded so characters such as
        ``@``, ``:`` or ``/`` cannot corrupt the URL.

        Args:
            host: Database server host name.
            port: Database server port.
            user: Login user name.
            pwd: Login password.
            db_name: Optional database name appended to the URL path.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.
        """
        db_url: str = (
            f"{cls.connect_driver}://{quote_plus(user)}:{quote_plus(pwd)}"
            f"@{host}:{port}"
        )
        if cls.dialect:
            db_url = cls.dialect + "+" + db_url
        if db_name:
            db_url = db_url + "/" + db_name
        return cls.from_uri(db_url, engine_args, **kwargs)

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> RDBMSDatabase:
        """Construct a SQLAlchemy engine from URI."""
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns ``include_tables`` when it was given; otherwise the tables
        found by the last ``get_session`` call minus ``ignore_tables``.
        """
        if self._include_tables:
            return self._include_tables
        return self._all_tables - self._ignore_tables

    def get_table_names(self) -> Iterable[str]:
        """Get names of tables available (deprecated alias)."""
        warnings.warn(
            "This method is deprecated - please use `get_usable_table_names`.",
            stacklevel=2,
        )
        return self.get_usable_table_names()

    def get_session(self, db_name: str):
        """Return the scoped session switched to ``db_name``.

        Also re-reflects table metadata for ``db_name`` and refreshes the
        list of known tables (plus views when ``view_support`` is truthy).
        """
        session = self._db_sessions()
        self._metadata = MetaData()
        # Double any backticks so the name cannot escape the quoted identifier.
        quoted_db_name = db_name.replace("`", "``")
        session.execute(text(f"use `{quoted_db_name}`"))
        # Reflect the table definitions of the selected database.
        self._metadata.reflect(bind=self._engine, schema=db_name)
        # view_support is not defined in this class; treat it as off unless
        # a subclass/instance provides it.
        include_views = getattr(self, "view_support", False)
        table_names = self._inspector.get_table_names(schema=db_name)
        view_names = (
            self._inspector.get_view_names(schema=db_name) if include_views else []
        )
        self._all_tables = set(table_names + view_names)
        return session

    def get_current_db_name(self, session) -> str:
        """Return the name of the database currently selected on ``session``."""
        return session.execute(text("SELECT DATABASE()")).scalar()

    def table_simple_info(self, session):
        """Return one ``table_name(col1,col2,...)`` row per table of the current database."""
        _sql = """
            select concat(table_name, "(" , group_concat(column_name), ")") as schema_info
            from information_schema.COLUMNS
            where table_schema = :schema
            group by TABLE_NAME;
        """
        # Bind the schema name instead of interpolating it into the SQL text.
        cursor = session.execute(
            text(_sql), {"schema": self.get_current_db_name(session)}
        )
        return cursor.fetchall()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Raises:
            ValueError: If any of ``table_names`` is not a usable table.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set once rather than once per reflected table.
        wanted_names = set(all_table_names)
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted_names
            and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
        ]

        # These optional settings are not assigned in this class; default to
        # "off" unless a subclass or caller provides them.
        custom_table_info = getattr(self, "_custom_table_info", None)
        indexes_in_table_info = getattr(self, "_indexes_in_table_info", False)
        sample_rows_in_table_info = getattr(self, "_sample_rows_in_table_info", 0)

        tables = []
        for table in meta_tables:
            if custom_table_info and table.name in custom_table_info:
                tables.append(custom_table_info[table.name])
                continue

            # add create table command
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = indexes_in_table_info or sample_rows_in_table_info
            if has_extra_info:
                table_info += "\n\n/*"
            if indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_rows(self, table: Table) -> str:
        """Return up to ``_sample_rows_in_table_info`` rows of ``table`` as tab-separated text."""
        # build the select command
        command = select(table).limit(self._sample_rows_in_table_info)

        # save the columns in string format
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # get the sample rows
            with self._engine.connect() as connection:
                sample_rows_result: CursorResult = connection.execute(command)
                # Truncate each value to its first 100 characters.
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # in some dialects when there are no rows in the table a
        # 'ProgrammingError' is returned
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _get_table_indexes(self, table: Table) -> str:
        """Return a text listing of the indexes defined on ``table``."""
        # Pass the reflected table's schema; without it the inspector looks
        # in the connection's default database instead of the selected one.
        indexes = self._inspector.get_indexes(table.name, schema=table.schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Returns ``"Error: ..."`` instead of raising when a requested table is
        not found.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command on ``session`` and return its result rows.

        Args:
            session: Session to execute on.
            command: Raw SQL text.
            fetch: ``"all"`` for every row, ``"one"`` for at most one row.

        Returns:
            For row-returning statements: a list whose first element is a
            tuple of column names, followed by the fetched rows. For
            statements that return no rows: an empty list.

        Raises:
            ValueError: If ``fetch`` is neither ``"one"`` nor ``"all"``
                (checked only for row-returning statements).
        """
        cursor = session.execute(text(command))
        if not cursor.returns_rows:
            return []
        if fetch == "all":
            result = list(cursor.fetchall())
        elif fetch == "one":
            # Keep the same shape as "all" (header + rows); fetchone() returns
            # None when the result set is empty.
            row = cursor.fetchone()
            result = [] if row is None else [row]
        else:
            raise ValueError("Fetch parameter must be either 'one' or 'all'")
        field_names = tuple(cursor.keys())
        result.insert(0, field_names)
        return result

    def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
        """Execute a SQL command and return its result rows.

        Same return value as ``run`` (header tuple followed by rows, or an
        empty list when the statement returns no rows). If the statement
        raises a ``SQLAlchemyError``, the string ``"Error: <message>"`` is
        returned instead.
        """
        try:
            return self.run(session, command, fetch)
        except SQLAlchemyError as e:
            # Format the error message
            return f"Error: {e}"

    def get_database_list(self):
        """Return the database names on the server, excluding system schemas."""
        session = self._db_sessions()
        cursor = session.execute(text(" show databases;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0] not in ["information_schema", "performance_schema", "sys", "mysql"]
        ]