mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-26 11:59:32 +00:00
style: fix code style with black
This commit is contained in:
parent
6a03c24ee7
commit
14dc5dac04
@ -5,9 +5,9 @@ from pilot.connections.rdbms.base import RDBMSDatabase
|
|||||||
|
|
||||||
|
|
||||||
class PostgreSQLDatabase(RDBMSDatabase):
|
class PostgreSQLDatabase(RDBMSDatabase):
|
||||||
driver = 'postgresql+psycopg2'
|
driver = "postgresql+psycopg2"
|
||||||
db_type = "postgresql"
|
db_type = "postgresql"
|
||||||
db_dialect = 'postgresql'
|
db_dialect = "postgresql"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_uri_db(
|
def from_uri_db(
|
||||||
@ -37,10 +37,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
|
|
||||||
def _sync_tables_from_db(self) -> Iterable[str]:
|
def _sync_tables_from_db(self) -> Iterable[str]:
|
||||||
table_results = self.session.execute(
|
table_results = self.session.execute(
|
||||||
text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
|
text(
|
||||||
|
"SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
view_results = self.session.execute(
|
view_results = self.session.execute(
|
||||||
text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
|
text(
|
||||||
|
"SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
|
||||||
|
)
|
||||||
)
|
)
|
||||||
table_results = set(row[0] for row in table_results)
|
table_results = set(row[0] for row in table_results)
|
||||||
view_results = set(row[0] for row in view_results)
|
view_results = set(row[0] for row in view_results)
|
||||||
@ -48,13 +52,16 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
self._metadata.reflect(bind=self._engine)
|
self._metadata.reflect(bind=self._engine)
|
||||||
return self._all_tables
|
return self._all_tables
|
||||||
|
|
||||||
|
|
||||||
def get_grants(self):
|
def get_grants(self):
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
cursor = session.execute(text(f"""
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
SELECT DISTINCT grantee, privilege_type
|
SELECT DISTINCT grantee, privilege_type
|
||||||
FROM information_schema.role_table_grants
|
FROM information_schema.role_table_grants
|
||||||
WHERE grantee = CURRENT_USER;"""))
|
WHERE grantee = CURRENT_USER;"""
|
||||||
|
)
|
||||||
|
)
|
||||||
grants = cursor.fetchall()
|
grants = cursor.fetchall()
|
||||||
return grants
|
return grants
|
||||||
|
|
||||||
@ -62,7 +69,11 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
"""Get collation."""
|
"""Get collation."""
|
||||||
try:
|
try:
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"))
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
|
||||||
|
)
|
||||||
|
)
|
||||||
collation = cursor.fetchone()[0]
|
collation = cursor.fetchone()[0]
|
||||||
return collation
|
return collation
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -72,7 +83,9 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
def get_users(self):
|
def get_users(self):
|
||||||
"""Get user info."""
|
"""Get user info."""
|
||||||
try:
|
try:
|
||||||
cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';"))
|
cursor = self.session.execute(
|
||||||
|
text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")
|
||||||
|
)
|
||||||
users = cursor.fetchall()
|
users = cursor.fetchall()
|
||||||
return [user[0] for user in users]
|
return [user[0] for user in users]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -95,12 +108,15 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
def get_charset(self):
|
def get_charset(self):
|
||||||
"""Get character_set."""
|
"""Get character_set."""
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"))
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
|
||||||
|
)
|
||||||
|
)
|
||||||
character_set = cursor.fetchone()[0]
|
character_set = cursor.fetchone()[0]
|
||||||
return character_set
|
return character_set
|
||||||
|
|
||||||
|
def get_show_create_table(self, table_name):
|
||||||
def get_show_create_table(self,table_name):
|
|
||||||
cur = self.session.execute(
|
cur = self.session.execute(
|
||||||
text(
|
text(
|
||||||
f"""
|
f"""
|
||||||
@ -119,7 +135,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
create_table_query = f"CREATE TABLE {table_name} (\n"
|
create_table_query = f"CREATE TABLE {table_name} (\n"
|
||||||
for row in rows:
|
for row in rows:
|
||||||
create_table_query += f" {row[0]} {row[1]},\n"
|
create_table_query += f" {row[0]} {row[1]},\n"
|
||||||
create_table_query = create_table_query.rstrip(',\n') + "\n)"
|
create_table_query = create_table_query.rstrip(",\n") + "\n)"
|
||||||
|
|
||||||
return create_table_query
|
return create_table_query
|
||||||
|
|
||||||
@ -137,9 +153,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return [
|
return [
|
||||||
d[0]
|
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||||
for d in results
|
|
||||||
if d[0] not in ["template0", "template1", "postgres"]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_database_names(self):
|
def get_database_names(self):
|
||||||
@ -147,9 +161,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
cursor = session.execute(text("SELECT datname FROM pg_database;"))
|
||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return [
|
return [
|
||||||
d[0]
|
d[0] for d in results if d[0] not in ["template0", "template1", "postgres"]
|
||||||
for d in results
|
|
||||||
if d[0] not in ["template0", "template1", "postgres"]
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_current_db_name(self) -> str:
|
def get_current_db_name(self) -> str:
|
||||||
@ -176,7 +188,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
results = cursor.fetchall()
|
results = cursor.fetchall()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def get_fields(self, table_name, schema_name='public'):
|
def get_fields(self, table_name, schema_name="public"):
|
||||||
"""Get column fields about specified table."""
|
"""Get column fields about specified table."""
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
cursor = session.execute(
|
cursor = session.execute(
|
||||||
@ -193,10 +205,13 @@ class PostgreSQLDatabase(RDBMSDatabase):
|
|||||||
fields = cursor.fetchall()
|
fields = cursor.fetchall()
|
||||||
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
|
||||||
|
|
||||||
|
|
||||||
def get_indexes(self, table_name):
|
def get_indexes(self, table_name):
|
||||||
"""Get table indexes about specified table."""
|
"""Get table indexes about specified table."""
|
||||||
session = self._db_sessions()
|
session = self._db_sessions()
|
||||||
cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"))
|
cursor = session.execute(
|
||||||
|
text(
|
||||||
|
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
|
||||||
|
)
|
||||||
|
)
|
||||||
indexes = cursor.fetchall()
|
indexes = cursor.fetchall()
|
||||||
return [(index[0], index[1]) for index in indexes]
|
return [(index[0], index[1]) for index in indexes]
|
||||||
|
2
setup.py
2
setup.py
@ -363,7 +363,7 @@ def all_datasource_requires():
|
|||||||
"""
|
"""
|
||||||
pip install "db-gpt[datasource]"
|
pip install "db-gpt[datasource]"
|
||||||
"""
|
"""
|
||||||
setup_spec.extras["datasource"] = ["pymssql", "pymysql","psycopg2"]
|
setup_spec.extras["datasource"] = ["pymssql", "pymysql", "psycopg2"]
|
||||||
|
|
||||||
|
|
||||||
def openai_requires():
|
def openai_requires():
|
||||||
|
Loading…
Reference in New Issue
Block a user