mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-25 19:39:35 +00:00
style: fix code style with black
This commit is contained in:
parent
6a03c24ee7
commit
14dc5dac04
@ -5,9 +5,9 @@ from pilot.connections.rdbms.base import RDBMSDatabase
|
||||
|
||||
|
||||
class PostgreSQLDatabase(RDBMSDatabase):
|
||||
driver = 'postgresql+psycopg2'
|
||||
driver = "postgresql+psycopg2"
|
||||
db_type = "postgresql"
|
||||
db_dialect = 'postgresql'
|
||||
db_dialect = "postgresql"
|
||||
|
||||
@classmethod
|
||||
def from_uri_db(
|
||||
@ -34,13 +34,17 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
+ db_name
|
||||
)
|
||||
return cls.from_uri(db_url, engine_args, **kwargs)
|
||||
|
||||
|
||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||
table_results = self.session.execute(
|
||||
text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
|
||||
text(
|
||||
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
view_results = self.session.execute(
|
||||
text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
|
||||
text(
|
||||
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||
)
|
||||
)
|
||||
table_results = set(row[0] for row in table_results)
|
||||
view_results = set(row[0] for row in view_results)
|
||||
@ -48,31 +52,40 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
self._metadata.reflect(bind=self._engine)
|
||||
return self._all_tables
|
||||
|
||||
|
||||
def get_grants(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"""
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT DISTINCT grantee, privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = CURRENT_USER;"""))
|
||||
WHERE grantee = CURRENT_USER;"""
|
||||
)
|
||||
)
|
||||
grants = cursor.fetchall()
|
||||
return grants
|
||||
|
||||
|
||||
def get_collation(self):
|
||||
"""Get collation."""
|
||||
try:
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"))
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
|
||||
)
|
||||
)
|
||||
collation = cursor.fetchone()[0]
|
||||
return collation
|
||||
except Exception as e:
|
||||
print("postgresql get collation error: ", e)
|
||||
return None
|
||||
|
||||
|
||||
def get_users(self):
|
||||
"""Get user info."""
|
||||
try:
|
||||
cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';"))
|
||||
cursor = self.session.execute(
|
||||
text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")
|
||||
)
|
||||
users = cursor.fetchall()
|
||||
return [user[0] for user in users]
|
||||
except Exception as e:
|
||||
@ -91,16 +104,19 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
)
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
|
||||
def get_charset(self):
|
||||
"""Get character_set."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"))
|
||||
cursor = session.execute(
|
||||
text(
|
||||
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
|
||||
)
|
||||
)
|
||||
character_set = cursor.fetchone()[0]
|
||||
return character_set
|
||||
|
||||
|
||||
def get_show_create_table(self,table_name):
|
||||
|
||||
def get_show_create_table(self, table_name):
|
||||
cur = self.session.execute(
|
||||
text(
|
||||
f"""
|
||||
@ -119,10 +135,10 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
create_table_query = f"CREATE TABLE {table_name} (\n"
|
||||
for row in rows:
|
||||
create_table_query += f" {row[0]} {row[1]},\n"
|
||||
create_table_query = create_table_query.rstrip(',\n') + "\n)"
|
||||
create_table_query = create_table_query.rstrip(",\n") + "\n)"
|
||||
|
||||
return create_table_query
|
||||
|
||||
|
||||
def get_table_comments(self, db_name=None):
|
||||
tablses = self.table_simple_info()
|
||||
comments = []
|
||||
@ -131,15 +147,13 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
table_comment = self.get_show_create_table(table_name)
|
||||
comments.append((table_name, table_comment))
|
||||
return comments
|
||||
|
||||
|
||||
def get_database_list(self):
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["template0", "template1", "postgres"]
|
||||
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||
]
|
||||
|
||||
def get_database_names(self):
|
||||
@ -147,11 +161,9 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||
results = cursor.fetchall()
|
||||
return [
|
||||
d[0]
|
||||
for d in results
|
||||
if d[0] not in ["template0", "template1", "postgres"]
|
||||
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||
]
|
||||
|
||||
|
||||
def get_current_db_name(self) -> str:
|
||||
return self.session.execute(text("SELECT current_database()")).scalar()
|
||||
|
||||
@ -176,7 +188,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
results = cursor.fetchall()
|
||||
return results
|
||||
|
||||
def get_fields(self, table_name, schema_name='public'):
|
||||
def get_fields(self, table_name, schema_name="public"):
|
||||
"""Get column fields about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(
|
||||
@ -193,10 +205,13 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
||||
fields = cursor.fetchall()
|
||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||
|
||||
|
||||
def get_indexes(self, table_name):
|
||||
"""Get table indexes about specified table."""
|
||||
session = self._db_sessions()
|
||||
cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"))
|
||||
cursor = session.execute(
|
||||
text(
|
||||
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
|
||||
)
|
||||
)
|
||||
indexes = cursor.fetchall()
|
||||
return [(index[0], index[1]) for index in indexes]
|
||||
|
Loading…
Reference in New Issue
Block a user