mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-26 21:37:40 +00:00
226 lines
8.0 KiB
Python
226 lines
8.0 KiB
Python
"""PostgreSQL connector."""
|
|
import logging
|
|
from typing import Any, Iterable, List, Optional, Tuple, cast
|
|
from urllib.parse import quote
|
|
from urllib.parse import quote_plus as urlquote
|
|
|
|
from sqlalchemy import text
|
|
|
|
from .base import RDBMSConnector
|
|
|
|
# Module logger, named after this module via ``__name__``.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class PostgreSQLConnector(RDBMSConnector):
    """PostgreSQL connector.

    Uses the ``postgresql+psycopg2`` SQLAlchemy driver and answers metadata
    queries (tables, columns, indexes, grants, encoding, ...) from the
    PostgreSQL system catalogs and ``information_schema``. Caller-supplied
    table/schema names are always sent as bound parameters, never
    interpolated into the SQL text.
    """

    driver = "postgresql+psycopg2"
    db_type = "postgresql"
    db_dialect = "postgresql"

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> "PostgreSQLConnector":
        """Create a new PostgreSQLConnector from host, port, user, pwd, db_name.

        Args:
            host: Database server host.
            port: Database server port.
            user: Database user name.
            pwd: Password of ``user``.
            db_name: Database to connect to.
            engine_args: Optional engine arguments, forwarded to ``from_uri``.
            **kwargs: Extra keyword arguments, forwarded to ``from_uri``.

        Returns:
            The connector built by ``from_uri`` for the assembled URL.
        """
        # Percent-encode every reserved character (safe="") in the
        # credentials: the default quote() leaves "/" unescaped, and
        # quote_plus() encodes a space as "+", which SQLAlchemy's URL parser
        # decodes as a literal "+" rather than a space.
        user_part = quote(user, safe="")
        pwd_part = quote(pwd, safe="")
        db_url: str = (
            f"{cls.driver}://{user_part}:{pwd_part}@{host}:{port}/{db_name}"
        )
        return cast(PostgreSQLConnector, cls.from_uri(db_url, engine_args, **kwargs))

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Refresh the cached table/view names and the reflected metadata.

        Collects unqualified table and view names from every schema except
        ``pg_catalog`` and ``information_schema``, stores their union in
        ``self._all_tables``, re-reflects ``self._metadata`` and returns the
        set of names.
        """
        table_results = self.session.execute(
            text(
                "SELECT tablename FROM pg_catalog.pg_tables WHERE "
                "schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
            )
        )
        view_results = self.session.execute(
            text(
                "SELECT viewname FROM pg_catalog.pg_views WHERE "
                "schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
            )
        )
        table_names = {row[0] for row in table_results}
        view_names = {row[0] for row in view_results}
        self._all_tables = table_names | view_names
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def get_grants(self):
        """Get the distinct table privileges granted to the current user.

        Returns:
            Rows of ``(grantee, privilege_type)`` from
            ``information_schema.role_table_grants``.
        """
        session = self._db_sessions()
        cursor = session.execute(
            text(
                """
                SELECT DISTINCT grantee, privilege_type
                FROM information_schema.role_table_grants
                WHERE grantee = CURRENT_USER;"""
            )
        )
        return cursor.fetchall()

    def get_collation(self):
        """Get the collation (``datcollate``) of the current database.

        Returns:
            The collation name, or ``None`` if the query fails (the error is
            logged as a warning).
        """
        try:
            session = self._db_sessions()
            cursor = session.execute(
                text(
                    "SELECT datcollate AS collation FROM pg_database WHERE "
                    "datname = current_database();"
                )
            )
            return cursor.scalar()
        except Exception as e:
            logger.warning("postgresql get collation error: %s", e)
            return None

    def get_users(self):
        """Get user info.

        Returns:
            Names of the roles in ``pg_roles`` except those starting with
            ``pg_`` (a prefix PostgreSQL reserves for its predefined roles),
            or ``[]`` if the query fails (the error is logged as a warning).
        """
        try:
            # "_" is a LIKE wildcard, so it is escaped to match only the
            # literal "pg_" prefix; roles such as "pgadmin" are kept.
            cursor = self.session.execute(
                text(
                    "SELECT rolname FROM pg_roles "
                    "WHERE rolname NOT LIKE 'pg!_%' ESCAPE '!';"
                )
            )
            return [user[0] for user in cursor.fetchall()]
        except Exception as e:
            logger.warning("postgresql get users error: %s", e)
            return []

    def get_fields(self, table_name, db_name=None) -> List[Tuple]:
        """Get column fields about specified table.

        Args:
            table_name: Unqualified table name; columns of same-named tables
                in every schema listed by ``information_schema.columns`` are
                returned.
            db_name: Unused; accepted for interface compatibility.

        Returns:
            ``(column_name, data_type, column_default, is_nullable,
            column_comment)`` tuples ordered by schema, then column position.
            ``column_comment`` is the column's comment, or ``""`` if none.
        """
        # col_description() reads the real column comment. quote_ident() keeps
        # the ::regclass cast valid for mixed-case / special-character names,
        # and COALESCE keeps the comment a string for comment-less columns.
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT column_name, data_type, column_default, is_nullable, "
                "COALESCE(pg_catalog.col_description("
                "(quote_ident(table_schema) || '.' || quote_ident(table_name))"
                "::regclass::oid, ordinal_position), '') AS column_comment "
                "FROM information_schema.columns WHERE table_name = :table_name "
                "ORDER BY table_schema, ordinal_position"
            ),
            {"table_name": table_name},
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_charset(self):
        """Get the character set (encoding name) of the current database."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE "
                "datname = current_database();"
            )
        )
        return cursor.scalar()

    def get_show_create_table(self, table_name: str):
        """Return a synthetic ``CREATE TABLE`` statement for ``table_name``.

        The statement lists each live (non-dropped, user) column with its
        formatted type, in column order; no constraints or defaults are
        included. When several schemas contain a relation with this name, the
        one visible on the current ``search_path`` is used.

        Args:
            table_name: Unqualified relation name (sent as a bound parameter).

        Returns:
            The ``CREATE TABLE`` text; ``"CREATE TABLE <name> (\\n)"`` when no
            matching relation is found.
        """
        # relkind: r=table, p=partitioned table, v=view, m=materialized view,
        # f=foreign table. ORDER BY ... LIMIT 1 avoids the "more than one row
        # returned by a subquery" error for names present in several schemas.
        cur = self.session.execute(
            text(
                """
                SELECT a.attname AS column_name,
                    pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type
                FROM pg_catalog.pg_attribute a
                WHERE a.attnum > 0
                    AND NOT a.attisdropped
                    AND a.attrelid = (
                        SELECT c.oid
                        FROM pg_catalog.pg_class c
                        WHERE c.relname = :table_name
                            AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
                        ORDER BY pg_catalog.pg_table_is_visible(c.oid) DESC, c.oid
                        LIMIT 1
                    )
                ORDER BY a.attnum
                """
            ),
            {"table_name": table_name},
        )
        rows = cur.fetchall()

        create_table_query = f"CREATE TABLE {table_name} (\n"
        for row in rows:
            create_table_query += f" {row[0]} {row[1]},\n"
        create_table_query = create_table_query.rstrip(",\n") + "\n)"

        return create_table_query

    def get_table_comments(self, db_name=None):
        """Get table comments.

        Table ``COMMENT``s are not read here; each table's synthetic
        ``CREATE TABLE`` text (from ``get_show_create_table``) is used as its
        description instead.

        Args:
            db_name: Unused; accepted for interface compatibility.

        Returns:
            ``(table_name, create_table_statement)`` tuples for the tables
            returned by ``table_simple_info``.
        """
        return [
            (table[0], self.get_show_create_table(table[0]))
            for table in self.table_simple_info()
        ]

    def get_database_names(self):
        """Get database names.

        ``template0``, ``template1`` and ``postgres`` are filtered out.
        """
        session = self._db_sessions()
        cursor = session.execute(text("SELECT datname FROM pg_database;"))
        excluded = {"template0", "template1", "postgres"}
        return [row[0] for row in cursor.fetchall() if row[0] not in excluded]

    def get_current_db_name(self) -> str:
        """Get current database name."""
        return self.session.execute(text("SELECT current_database()")).scalar()

    def table_simple_info(self):
        """Get table simple info.

        Returns:
            Rows of ``(table_name, schema_info)`` for ordinary tables
            (``relkind = 'r'``) outside ``pg_*`` schemas and
            ``information_schema``, ordered by table name. ``schema_info`` is
            the comma-separated list of live column names in column order.
            Same-named tables from different schemas are merged into one row.
        """
        # ORDER BY inside string_agg() guarantees the column order (ordering a
        # subquery does not). "_" is a LIKE wildcard, hence the escaped
        # pattern so schemas such as "pgdata" are not skipped.
        _sql = """
            SELECT c.relname AS table_name,
                string_agg(a.attname, ', ' ORDER BY a.attnum) AS schema_info
            FROM pg_catalog.pg_class c
            JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
            WHERE c.relkind = 'r'
                AND a.attnum > 0
                AND NOT a.attisdropped
                AND n.nspname NOT LIKE 'pg!_%' ESCAPE '!'
                AND n.nspname != 'information_schema'
            GROUP BY c.relname
            ORDER BY c.relname;
        """
        cursor = self.session.execute(text(_sql))
        return cursor.fetchall()

    def get_fields_wit_schema(self, table_name, schema_name="public"):
        """Get column fields about specified table in ``schema_name``.

        Args:
            table_name: Table name (sent as a bound parameter).
            schema_name: Schema containing the table.

        Returns:
            ``(column_name, data_type, column_default, is_nullable,
            description)`` tuples in column order; ``description`` is ``None``
            when the column has no comment.
        """
        # col_description() filters pg_description on the pg_class catalog;
        # quote_ident() keeps ::regclass valid for mixed-case names.
        session = self._db_sessions()
        cursor = session.execute(
            text(
                """
                SELECT c.column_name, c.data_type, c.column_default,
                    c.is_nullable,
                    pg_catalog.col_description(
                        (quote_ident(c.table_schema) || '.' || quote_ident(c.table_name))::regclass::oid,
                        c.ordinal_position
                    ) AS description
                FROM information_schema.columns c
                WHERE c.table_name = :table_name
                    AND c.table_schema = :schema_name
                ORDER BY c.ordinal_position
                """
            ),
            {"table_name": table_name, "schema_name": schema_name},
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_indexes(self, table_name):
        """Get table indexes about specified table.

        Args:
            table_name: Unqualified table name (sent as a bound parameter);
                indexes of same-named tables in all schemas are included.

        Returns:
            ``(indexname, indexdef)`` tuples from ``pg_indexes``.
        """
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT indexname, indexdef FROM pg_indexes WHERE "
                "tablename = :table_name"
            ),
            {"table_name": table_name},
        )
        indexes = cursor.fetchall()
        return [(index[0], index[1]) for index in indexes]