style: fix code style with black

This commit is contained in:
lozzow 2023-09-25 14:29:16 +00:00
parent 6a03c24ee7
commit 14dc5dac04
2 changed files with 46 additions and 31 deletions

View File

@ -5,9 +5,9 @@ from pilot.connections.rdbms.base import RDBMSDatabase
class PostgreSQLDatabase(RDBMSDatabase): class PostgreSQLDatabase(RDBMSDatabase):
driver = 'postgresql+psycopg2' driver = "postgresql+psycopg2"
db_type = "postgresql" db_type = "postgresql"
db_dialect = 'postgresql' db_dialect = "postgresql"
@classmethod @classmethod
def from_uri_db( def from_uri_db(
@ -37,10 +37,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
def _sync_tables_from_db(self) -> Iterable[str]: def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute( table_results = self.session.execute(
text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") text(
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
)
) )
view_results = self.session.execute( view_results = self.session.execute(
text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") text(
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
)
) )
table_results = set(row[0] for row in table_results) table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results) view_results = set(row[0] for row in view_results)
@ -48,13 +52,16 @@ class PostgreSQLDatabase(RDBMSDatabase):
self._metadata.reflect(bind=self._engine) self._metadata.reflect(bind=self._engine)
return self._all_tables return self._all_tables
def get_grants(self): def get_grants(self):
session = self._db_sessions() session = self._db_sessions()
cursor = session.execute(text(f""" cursor = session.execute(
text(
f"""
SELECT DISTINCT grantee, privilege_type SELECT DISTINCT grantee, privilege_type
FROM information_schema.role_table_grants FROM information_schema.role_table_grants
WHERE grantee = CURRENT_USER;""")) WHERE grantee = CURRENT_USER;"""
)
)
grants = cursor.fetchall() grants = cursor.fetchall()
return grants return grants
@ -62,7 +69,11 @@ class PostgreSQLDatabase(RDBMSDatabase):
"""Get collation.""" """Get collation."""
try: try:
session = self._db_sessions() session = self._db_sessions()
cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) cursor = session.execute(
text(
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
)
)
collation = cursor.fetchone()[0] collation = cursor.fetchone()[0]
return collation return collation
except Exception as e: except Exception as e:
@ -72,7 +83,9 @@ class PostgreSQLDatabase(RDBMSDatabase):
def get_users(self): def get_users(self):
"""Get user info.""" """Get user info."""
try: try:
cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")) cursor = self.session.execute(
text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")
)
users = cursor.fetchall() users = cursor.fetchall()
return [user[0] for user in users] return [user[0] for user in users]
except Exception as e: except Exception as e:
@ -95,11 +108,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
def get_charset(self): def get_charset(self):
"""Get character_set.""" """Get character_set."""
session = self._db_sessions() session = self._db_sessions()
cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();")) cursor = session.execute(
text(
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
)
)
character_set = cursor.fetchone()[0] character_set = cursor.fetchone()[0]
return character_set return character_set
def get_show_create_table(self, table_name): def get_show_create_table(self, table_name):
cur = self.session.execute( cur = self.session.execute(
text( text(
@ -119,7 +135,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
create_table_query = f"CREATE TABLE {table_name} (\n" create_table_query = f"CREATE TABLE {table_name} (\n"
for row in rows: for row in rows:
create_table_query += f" {row[0]} {row[1]},\n" create_table_query += f" {row[0]} {row[1]},\n"
create_table_query = create_table_query.rstrip(',\n') + "\n)" create_table_query = create_table_query.rstrip(",\n") + "\n)"
return create_table_query return create_table_query
@ -137,9 +153,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
cursor = session.execute(text("SELECT datname FROM pg_database;")) cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall() results = cursor.fetchall()
return [ return [
d[0] d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
] ]
def get_database_names(self): def get_database_names(self):
@ -147,9 +161,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
cursor = session.execute(text("SELECT datname FROM pg_database;")) cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall() results = cursor.fetchall()
return [ return [
d[0] d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
] ]
def get_current_db_name(self) -> str: def get_current_db_name(self) -> str:
@ -176,7 +188,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
results = cursor.fetchall() results = cursor.fetchall()
return results return results
def get_fields(self, table_name, schema_name='public'): def get_fields(self, table_name, schema_name="public"):
"""Get column fields about specified table.""" """Get column fields about specified table."""
session = self._db_sessions() session = self._db_sessions()
cursor = session.execute( cursor = session.execute(
@ -193,10 +205,13 @@ class PostgreSQLDatabase(RDBMSDatabase):
fields = cursor.fetchall() fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_indexes(self, table_name): def get_indexes(self, table_name):
"""Get table indexes about specified table.""" """Get table indexes about specified table."""
session = self._db_sessions() session = self._db_sessions()
cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'")) cursor = session.execute(
text(
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
)
)
indexes = cursor.fetchall() indexes = cursor.fetchall()
return [(index[0], index[1]) for index in indexes] return [(index[0], index[1]) for index in indexes]