feat: Add support for Vertica analytical database (#1538)

This commit is contained in:
DQ
2024-05-20 09:45:40 +08:00
committed by GitHub
parent 6fb3d33bf4
commit 8d8411fcd3
16 changed files with 658 additions and 5 deletions

View File

@@ -0,0 +1,256 @@
"""Vertica connector."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from sqlalchemy.dialects import registry
from .base import RDBMSConnector
logger = logging.getLogger(__name__)
registry.register(
"vertica.vertica_python",
"dbgpt.datasource.rdbms.dialect.vertica.dialect_vertica_python",
"VerticaDialect",
)
class VerticaConnector(RDBMSConnector):
"""Vertica connector."""
driver = "vertica+vertica_python"
db_type = "vertica"
db_dialect = "vertica"
@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> "VerticaConnector":
"""Create a new VerticaConnector from host, port, user, pwd, db_name."""
db_url: str = (
f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
)
return cast(VerticaConnector, cls.from_uri(db_url, engine_args, **kwargs))
@property
def dialect(self) -> str:
"""Return string representation of dialect to use."""
# inject instruction to prompt according to {dialect} in prompt template.
return "Vertica sql, \
correct postgresql sql is the another option \
if you don't know much about Vertica. \
尤其要注意,表名称前面一定要带上模式名称!! \
Note the most important requirement is that \
table name should keep its schema name in "
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
text(
"""
SELECT table_schema||'.'||table_name
FROM v_catalog.tables
WHERE table_schema NOT LIKE 'v\_%'
UNION
SELECT table_schema||'.'||table_name
FROM v_catalog.views
WHERE table_schema NOT LIKE 'v\_%';
"""
)
)
self._all_tables = {row[0] for row in table_results}
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
"""Get grants."""
return []
def get_collation(self):
"""Get collation."""
return None
def get_users(self):
"""Get user info."""
try:
cursor = self.session.execute(text("SELECT name FROM v_internal.vs_users;"))
users = cursor.fetchall()
return [user[0] for user in users]
except Exception as e:
logger.warning(f"vertica get users error: {str(e)}")
return []
def get_fields(self, table_name) -> List[Tuple]:
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT column_name, data_type, column_default, is_nullable,
nvl(comment, column_name) as column_comment
FROM v_catalog.columns c
LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
AND c.column_name = s.childobject
WHERE table_schema||'.'||table_name = '{table_name}';
"""
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_columns(self, table_name: str) -> List[Dict]:
"""Get columns about specified table.
Args:
table_name (str): table name
Returns:
columns: List[Dict], which contains name: str, type: str,
default_expression: str, is_in_primary_key: bool, comment: str
eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
'is_in_primary_key': True, 'comment': 'id'}, ...]
"""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT c.column_name, data_type, column_default
, (p.column_name IS NOT NULL) is_in_primary_key
, nvl(comment, c.column_name) as column_comment
FROM v_catalog.columns c
LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
AND c.column_name = s.childobject
LEFT JOIN v_catalog.primary_keys p ON c.table_schema = p.table_schema
AND c.table_name = p.table_name
AND c.column_name = p.column_name
WHERE c.table_schema||'.'||c.table_name = '{table_name}';
"""
)
)
fields = cursor.fetchall()
return [
{
"name": field[0],
"type": field[1],
"default_expression": field[2],
"is_in_primary_key": field[3],
"comment": field[4],
}
for field in fields
]
def get_charset(self):
"""Get character_set."""
return "utf-8"
def get_show_create_table(self, table_name: str):
"""Return show create table."""
cur = self.session.execute(
text(
f"""
SELECT column_name, data_type
FROM v_catalog.columns
WHERE table_schema||'.'||table_name = '{table_name}';
"""
)
)
rows = cur.fetchall()
create_table_query = f"CREATE TABLE {table_name} (\n"
for row in rows:
create_table_query += f" {row[0]} {row[1]},\n"
create_table_query = create_table_query.rstrip(",\n") + "\n)"
return create_table_query
def get_table_comments(self, db_name=None):
"""Return table comments."""
cursor = self.session.execute(
text(
f"""
SELECT table_schema||'.'||table_name
, nvl(comment, table_name) as column_comment
FROM v_catalog.tables t
LEFT JOIN v_internal.vs_comments c ON t.table_id = c.objectoid
WHERE table_schema = '{db_name}'
"""
)
)
table_comments = cursor.fetchall()
return [
(table_comment[0], table_comment[1]) for table_comment in table_comments
]
def get_table_comment(self, table_name: str) -> Dict:
"""Get table comments.
Args:
table_name (str): table name
Returns:
comment: Dict, which contains text: Optional[str], eg:["text": "comment"]
"""
cursor = self.session.execute(
text(
f"""
SELECT nvl(comment, table_name) as column_comment
FROM v_catalog.tables t
LEFT JOIN v_internal.vs_comments c ON t.table_id = c.objectoid
WHERE table_schema||'.'||table_nam e= '{table_name}'
"""
)
)
return {"text": cursor.scalar()}
def get_column_comments(self, db_name: str, table_name: str):
"""Return column comments."""
cursor = self.session.execute(
text(
f"""
SELECT column_name, nvl(comment, column_name) as column_comment
FROM v_catalog.columns c
LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
AND c.column_name = s.childobject
WHERE table_schema = '{db_name}' AND table_name = '{table_name}'
"""
)
)
column_comments = cursor.fetchall()
return [
(column_comment[0], column_comment[1]) for column_comment in column_comments
]
def get_database_names(self):
"""Get database names."""
session = self._db_sessions()
cursor = session.execute(text("SELECT schema_name FROM v_catalog.schemata;"))
results = cursor.fetchall()
return [d[0] for d in results if not d[0].startswith("v_")]
def get_current_db_name(self) -> str:
"""Get current database name."""
return self.session.execute(text("SELECT current_schema()")).scalar()
def table_simple_info(self):
"""Get table simple info."""
_sql = """
SELECT table_schema||'.'||table_name
, listagg(column_name using parameters max_length=65000)
FROM v_catalog.columns
WHERE table_schema NOT LIKE 'v\_%'
GROUP BY 1;
"""
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
return results
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
return []

View File

@@ -0,0 +1,179 @@
"""Base class for Vertica dialect."""
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)
import logging
import re
from typing import Any, Optional
from sqlalchemy import sql
from sqlalchemy.engine import default, reflection
logger: logging.Logger = logging.getLogger(__name__)
class VerticaInspector(reflection.Inspector):
"""Reflection inspector for Vertica."""
dialect: VerticaDialect
def get_all_columns(self, table, schema: Optional[str] = None, **kw: Any):
r"""Return all table columns names within a particular schema."""
return self.dialect.get_all_columns(
self.bind, table, schema, info_cache=self.info_cache, **kw
)
def get_table_comment(self, table_name: str, schema: Optional[str] = None, **kw):
"""Return comment of a table in a schema."""
return self.dialect.get_table_comment(
self.bind, table_name, schema, info_cache=self.info_cache, **kw
)
def get_view_columns(
self, view: Optional[str] = None, schema: Optional[str] = None, **kw: Any
):
r"""Return all view columns names within a particular schema."""
return self.dialect.get_view_columns(
self.bind, view, schema, info_cache=self.info_cache, **kw
)
def get_view_comment(
self, view: Optional[str] = None, schema: Optional[str] = None, **kw
):
r"""Return view comments within a particular schema."""
return self.dialect.get_view_comment(
self.bind, view, schema, info_cache=self.info_cache, **kw
)
class VerticaDialect(default.DefaultDialect):
"""Vertica dialect."""
name = "vertica"
inspector = VerticaInspector
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
"""Init object."""
default.DefaultDialect.__init__(self, **kwargs)
self._json_deserializer = json_deserializer
self._json_serializer = json_serializer
def initialize(self, connection):
"""Init dialect."""
super().initialize(connection)
def _get_default_schema_name(self, connection):
return connection.scalar(sql.text("SELECT current_schema()"))
def _get_server_version_info(self, connection):
v = connection.scalar(sql.text("SELECT version()"))
m = re.match(r".*Vertica Analytic Database v(\d+)\.(\d+)\.(\d)+.*", v)
if not m:
raise AssertionError(
"Could not determine version from string '%(ver)s'" % {"ver": v}
)
return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
def create_connect_args(self, url):
"""Create args of connection."""
opts = url.translate_connect_args(username="user")
opts.update(url.query)
return [], opts
def has_table(self, connection, table_name, schema=None):
"""Check availability of a table."""
return False
def has_sequence(self, connection, sequence_name, schema=None):
"""Check availability of a sequence."""
return False
def has_type(self, connection, type_name):
"""Check availability of a type."""
return False
def get_schema_names(self, connection, **kw):
"""Return names of all schemas."""
return []
def get_table_comment(self, connection, table_name, schema=None, **kw):
"""Return comment of a table in a schema."""
return {"text": table_name}
def get_table_names(self, connection, schema=None, **kw):
"""Get names of tables in a schema."""
return []
def get_temp_table_names(self, connection, schema=None, **kw):
"""Get names of temp tables in a schema."""
return []
def get_view_names(self, connection, schema=None, **kw):
"""Get names of views in a schema."""
return []
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""Get definition of views in a schema."""
return view_name
def get_temp_view_names(self, connection, schema=None, **kw):
"""Get names of temp views in a schema."""
return []
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
"""Get unique constrains of a table in a schema."""
return []
def get_check_constraints(self, connection, table_name, schema=None, **kw):
"""Get checks of a table in a schema."""
return []
def normalize_name(self, name):
"""Normalize name."""
name = name and name.rstrip()
if name is None:
return None
return name.lower()
def denormalize_name(self, name):
"""Denormalize name."""
return name
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
"""Get poreignn keys of a table in a schema."""
return []
def get_indexes(self, connection, table_name, schema, **kw):
"""Get indexes of a table in a schema."""
return []
def visit_create_index(self, create):
"""Disable index creation since that's not a thing in Vertica."""
return None
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
"""Get primary keye of a table in a schema."""
return None
def get_all_columns(self, connection, table, schema=None, **kw):
"""Get all columns of a table in a schema."""
return []
def get_columns(self, connection, table_name, schema=None, **kw):
"""Get all columns of a table in a schema."""
return self.get_all_columns(connection, table_name, schema)
def get_view_columns(self, connection, view, schema=None, **kw):
"""Get columns of views in a schema."""
return []
def get_view_comment(self, connection, view, schema=None, **kw):
"""Get comment of view."""
return {"text": view}

View File

@@ -0,0 +1,23 @@
"""Vertica dialect."""
from __future__ import absolute_import, division, print_function
from .base import VerticaDialect as BaseVerticaDialect
# noinspection PyAbstractClass, PyClassHasNoInit
class VerticaDialect(BaseVerticaDialect):
"""Vertica dialect class."""
driver = "vertica_python"
# TODO: support SQL caching, for more info see:
# https://docs.sqlalchemy.org/en/14/core/connections.html#caching-for-third-party-dialects
supports_statement_cache = False
# No lastrowid support. TODO support SELECT LAST_INSERT_ID();
postfetch_lastrowid = False
@classmethod
def dbapi(cls):
"""Get Driver."""
vertica_python = __import__("vertica_python")
return vertica_python