refactor: Refactor datasource module (#1309)

This commit is contained in:
Fangyin Cheng
2024-03-18 18:06:40 +08:00
committed by GitHub
parent 84bedee306
commit 4970c9f813
108 changed files with 1194 additions and 1066 deletions

View File

@@ -1,13 +1,19 @@
from typing import Any, Iterable, List, Optional, Tuple
"""PostgreSQL connector."""
import logging
from typing import Any, Iterable, List, Optional, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from dbgpt.datasource.rdbms.base import RDBMSDatabase
from .base import RDBMSConnector
logger = logging.getLogger(__name__)
class PostgreSQLDatabase(RDBMSDatabase):
class PostgreSQLConnector(RDBMSConnector):
"""PostgreSQL connector."""
driver = "postgresql+psycopg2"
db_type = "postgresql"
db_dialect = "postgresql"
@@ -22,34 +28,38 @@ class PostgreSQLDatabase(RDBMSDatabase):
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
) -> "PostgreSQLConnector":
"""Create a new PostgreSQLConnector from host, port, user, pwd, db_name."""
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cls.from_uri(db_url, engine_args, **kwargs)
return cast(PostgreSQLConnector, cls.from_uri(db_url, engine_args, **kwargs))
def _sync_tables_from_db(self) -> Iterable[str]:
    """Refresh the cached set of table and view names from the database.

    Queries ``pg_catalog.pg_tables`` and ``pg_catalog.pg_views`` for every
    relation outside the ``pg_catalog`` and ``information_schema`` schemas,
    stores the union of the names on ``self._all_tables`` and reflects
    ``self._metadata`` against ``self._engine``.

    Returns:
        The combined set of table and view names.
    """
    table_rows = self.session.execute(
        text(
            "SELECT tablename FROM pg_catalog.pg_tables WHERE "
            "schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
        )
    )
    view_rows = self.session.execute(
        text(
            "SELECT viewname FROM pg_catalog.pg_views WHERE "
            "schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
        )
    )
    # Each result row has a single column; keep only the relation name.
    # Distinct names are used so the name sets are built exactly once and
    # never re-indexed (which would reduce each name to its first character).
    table_names = {row[0] for row in table_rows}
    view_names = {row[0] for row in view_rows}
    self._all_tables = table_names.union(view_names)
    self._metadata.reflect(bind=self._engine)
    return self._all_tables
def get_grants(self):
"""Get grants."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
"""
SELECT DISTINCT grantee, privilege_type
FROM information_schema.role_table_grants
WHERE grantee = CURRENT_USER;"""
@@ -64,13 +74,14 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
"SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"
"SELECT datcollate AS collation FROM pg_database WHERE "
"datname = current_database();"
)
)
collation = cursor.fetchone()[0]
return collation
except Exception as e:
print("postgresql get collation error: ", e)
logger.warning(f"postgresql get collation error: {str(e)}")
return None
def get_users(self):
@@ -82,7 +93,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
users = cursor.fetchall()
return [user[0] for user in users]
except Exception as e:
print("postgresql get users error: ", e)
logger.warning(f"postgresql get users error: {str(e)}")
return []
def get_fields(self, table_name) -> List[Tuple]:
@@ -90,7 +101,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
"SELECT column_name, data_type, column_default, is_nullable, "
"column_name as column_comment \
FROM information_schema.columns WHERE table_name = :table_name",
),
{"table_name": table_name},
@@ -103,23 +115,28 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"
"SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE "
"datname = current_database();"
)
)
character_set = cursor.fetchone()[0]
return character_set
def get_show_create_table(self, table_name):
def get_show_create_table(self, table_name: str):
"""Return show create table."""
cur = self.session.execute(
text(
f"""
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
SELECT a.attname as column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
FROM pg_catalog.pg_attribute a
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
SELECT max(a.attnum)
FROM pg_catalog.pg_attribute a
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class
WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class
WHERE relname='{table_name}')
"""
)
)
@@ -133,6 +150,7 @@ class PostgreSQLDatabase(RDBMSDatabase):
return create_table_query
def get_table_comments(self, db_name=None):
"""Get table comments."""
tablses = self.table_simple_info()
comments = []
for table in tablses:
@@ -141,15 +159,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
comments.append((table_name, table_comment))
return comments
def get_database_list(self):
    """List database names from ``pg_database``, skipping the built-in ones.

    Returns:
        The ``datname`` values other than ``template0``, ``template1`` and
        ``postgres``.
    """
    excluded = ["template0", "template1", "postgres"]
    rows = (
        self._db_sessions()
        .execute(text("SELECT datname FROM pg_database;"))
        .fetchall()
    )
    # The query selects a single column, so the name is the first field.
    names = (row[0] for row in rows)
    return [name for name in names if name not in excluded]
def get_database_names(self):
"""Get database names."""
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
@@ -158,10 +169,12 @@ class PostgreSQLDatabase(RDBMSDatabase):
]
def get_current_db_name(self) -> str:
    """Return the result of ``SELECT current_database()`` for this session."""
    session = self.session
    result = session.execute(text("SELECT current_database()"))
    # scalar() yields the first column of the first row.
    return result.scalar()
def table_simple_info(self):
_sql = f"""
"""Get table simple info."""
_sql = """
SELECT table_name, string_agg(column_name, ', ') AS schema_info
FROM (
SELECT c.relname AS table_name, a.attname AS column_name
@@ -181,17 +194,18 @@ class PostgreSQLDatabase(RDBMSDatabase):
results = cursor.fetchall()
return results
def get_fields(self, table_name, schema_name="public"):
def get_fields_wit_schema(self, table_name, schema_name="public"):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable,
d.description FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid
AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
"""
)
)
@@ -203,7 +217,8 @@ class PostgreSQLDatabase(RDBMSDatabase):
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"
f"SELECT indexname, indexdef FROM pg_indexes WHERE "
f"tablename = '{table_name}'"
)
)
indexes = cursor.fetchall()