From 14dc5dac04aaa05b4ffe7ad772dc5546de7f9b2b Mon Sep 17 00:00:00 2001 From: lozzow Date: Mon, 25 Sep 2023 14:29:16 +0000 Subject: [PATCH] style: fix code style with black --- pilot/connections/rdbms/conn_postgresql.py | 75 +++++++++++++--------- setup.py | 2 +- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/pilot/connections/rdbms/conn_postgresql.py b/pilot/connections/rdbms/conn_postgresql.py index aab7e4ccc..2eeca4f04 100644 --- a/pilot/connections/rdbms/conn_postgresql.py +++ b/pilot/connections/rdbms/conn_postgresql.py @@ -5,9 +5,9 @@ from pilot.connections.rdbms.base import RDBMSDatabase class PostgreSQLDatabase(RDBMSDatabase): - driver = 'postgresql+psycopg2' + driver = "postgresql+psycopg2" db_type = "postgresql" - db_dialect = 'postgresql' + db_dialect = "postgresql" @classmethod def from_uri_db( @@ -34,13 +34,17 @@ class PostgreSQLDatabase(RDBMSDatabase): + db_name ) return cls.from_uri(db_url, engine_args, **kwargs) - + def _sync_tables_from_db(self) -> Iterable[str]: table_results = self.session.execute( - text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + text( + "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + ) ) view_results = self.session.execute( - text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + text( + "SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + ) ) table_results = set(row[0] for row in table_results) view_results = set(row[0] for row in view_results) @@ -48,31 +52,40 @@ class PostgreSQLDatabase(RDBMSDatabase): self._metadata.reflect(bind=self._engine) return self._all_tables - def get_grants(self): session = self._db_sessions() - cursor = session.execute(text(f""" + cursor = session.execute( + text( + f""" SELECT DISTINCT grantee, privilege_type FROM information_schema.role_table_grants - WHERE grantee = CURRENT_USER;""")) + WHERE grantee = CURRENT_USER;""" + ) + ) grants = cursor.fetchall() return grants - + def get_collation(self): """Get collation.""" try: session = self._db_sessions() - cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) + cursor = session.execute( + text( + "SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();" + ) + ) collation = cursor.fetchone()[0] return collation except Exception as e: print("postgresql get collation error: ", e) return None - + def get_users(self): """Get user info.""" try: - cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")) + cursor = self.session.execute( + text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';") + ) users = cursor.fetchall() return [user[0] for user in users] except Exception as e: @@ -91,16 +104,19 @@ class PostgreSQLDatabase(RDBMSDatabase): ) fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] - + def get_charset(self): """Get character_set.""" session = self._db_sessions() - cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();")) + cursor = session.execute( + text( + "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();" + ) + ) character_set = cursor.fetchone()[0] return character_set - - - def get_show_create_table(self,table_name): + + def get_show_create_table(self, table_name): cur = self.session.execute( text( f""" @@ -119,10 +135,10 @@ class PostgreSQLDatabase(RDBMSDatabase): create_table_query = f"CREATE TABLE {table_name} (\n" for row in rows: create_table_query += f" {row[0]} {row[1]},\n" - create_table_query = create_table_query.rstrip(',\n') + "\n)" + create_table_query = create_table_query.rstrip(",\n") + "\n)" return create_table_query - + def get_table_comments(self, db_name=None): tablses = self.table_simple_info() comments = [] @@ -131,15 +147,13 @@ class PostgreSQLDatabase(RDBMSDatabase): table_comment = self.get_show_create_table(table_name) comments.append((table_name, table_comment)) return comments - + def get_database_list(self): session = self._db_sessions() cursor = session.execute(text("SELECT datname FROM pg_database;")) results = cursor.fetchall() return [ - d[0] - for d in results - if d[0] not in ["template0", "template1", "postgres"] + d[0] for d in results if d[0] not in ["template0", "template1", "postgres"] ] def get_database_names(self): @@ -147,11 +161,9 @@ class PostgreSQLDatabase(RDBMSDatabase): cursor = session.execute(text("SELECT datname FROM pg_database;")) results = cursor.fetchall() return [ - d[0] - for d in results - if d[0] not in ["template0", "template1", "postgres"] + d[0] for d in results if d[0] not in ["template0", "template1", "postgres"] ] - + def get_current_db_name(self) -> str: return self.session.execute(text("SELECT current_database()")).scalar() @@ -176,7 +188,7 @@ class PostgreSQLDatabase(RDBMSDatabase): results = cursor.fetchall() return results - def get_fields(self, table_name, schema_name='public'): + def get_fields(self, table_name, schema_name="public"): """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( @@ -193,10 +205,13 @@ class PostgreSQLDatabase(RDBMSDatabase): fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] - def get_indexes(self, table_name): """Get table indexes about specified table.""" session = self._db_sessions() - cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'")) + cursor = session.execute( + text( + f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'" + ) + ) indexes = cursor.fetchall() return [(index[0], index[1]) for index in indexes] diff --git a/setup.py b/setup.py index 3521c8d8c..7a3b0b5b8 100644 --- a/setup.py +++ b/setup.py @@ -363,7 +363,7 @@ def all_datasource_requires(): """ pip install "db-gpt[datasource]" """ - setup_spec.extras["datasource"] = ["pymssql", "pymysql","psycopg2"] + setup_spec.extras["datasource"] = ["pymssql", "pymysql", "psycopg2"] def openai_requires():