DB-GPT/dbgpt/datasource/rdbms/dialect/vertica/base.py

180 lines
5.8 KiB
Python

"""Base class for Vertica dialect."""
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)
import logging
import re
from typing import Any, Optional
from sqlalchemy import sql
from sqlalchemy.engine import default, reflection
logger: logging.Logger = logging.getLogger(__name__)
class VerticaInspector(reflection.Inspector):
"""Reflection inspector for Vertica."""
dialect: VerticaDialect
def get_all_columns(self, table, schema: Optional[str] = None, **kw: Any):
r"""Return all table columns names within a particular schema."""
return self.dialect.get_all_columns(
self.bind, table, schema, info_cache=self.info_cache, **kw
)
def get_table_comment(self, table_name: str, schema: Optional[str] = None, **kw):
"""Return comment of a table in a schema."""
return self.dialect.get_table_comment(
self.bind, table_name, schema, info_cache=self.info_cache, **kw
)
def get_view_columns(
self, view: Optional[str] = None, schema: Optional[str] = None, **kw: Any
):
r"""Return all view columns names within a particular schema."""
return self.dialect.get_view_columns(
self.bind, view, schema, info_cache=self.info_cache, **kw
)
def get_view_comment(
self, view: Optional[str] = None, schema: Optional[str] = None, **kw
):
r"""Return view comments within a particular schema."""
return self.dialect.get_view_comment(
self.bind, view, schema, info_cache=self.info_cache, **kw
)
class VerticaDialect(default.DefaultDialect):
"""Vertica dialect."""
name = "vertica"
inspector = VerticaInspector
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
"""Init object."""
default.DefaultDialect.__init__(self, **kwargs)
self._json_deserializer = json_deserializer
self._json_serializer = json_serializer
def initialize(self, connection):
"""Init dialect."""
super().initialize(connection)
def _get_default_schema_name(self, connection):
return connection.scalar(sql.text("SELECT current_schema()"))
def _get_server_version_info(self, connection):
v = connection.scalar(sql.text("SELECT version()"))
m = re.match(r".*Vertica Analytic Database v(\d+)\.(\d+)\.(\d)+.*", v)
if not m:
raise AssertionError(
"Could not determine version from string '%(ver)s'" % {"ver": v}
)
return tuple([int(x) for x in m.group(1, 2, 3) if x is not None])
def create_connect_args(self, url):
"""Create args of connection."""
opts = url.translate_connect_args(username="user")
opts.update(url.query)
return [], opts
def has_table(self, connection, table_name, schema=None):
"""Check availability of a table."""
return False
def has_sequence(self, connection, sequence_name, schema=None):
"""Check availability of a sequence."""
return False
def has_type(self, connection, type_name):
"""Check availability of a type."""
return False
def get_schema_names(self, connection, **kw):
"""Return names of all schemas."""
return []
def get_table_comment(self, connection, table_name, schema=None, **kw):
"""Return comment of a table in a schema."""
return {"text": table_name}
def get_table_names(self, connection, schema=None, **kw):
"""Get names of tables in a schema."""
return []
def get_temp_table_names(self, connection, schema=None, **kw):
"""Get names of temp tables in a schema."""
return []
def get_view_names(self, connection, schema=None, **kw):
"""Get names of views in a schema."""
return []
def get_view_definition(self, connection, view_name, schema=None, **kw):
"""Get definition of views in a schema."""
return view_name
def get_temp_view_names(self, connection, schema=None, **kw):
"""Get names of temp views in a schema."""
return []
def get_unique_constraints(self, connection, table_name, schema=None, **kw):
"""Get unique constrains of a table in a schema."""
return []
def get_check_constraints(self, connection, table_name, schema=None, **kw):
"""Get checks of a table in a schema."""
return []
def normalize_name(self, name):
"""Normalize name."""
name = name and name.rstrip()
if name is None:
return None
return name.lower()
def denormalize_name(self, name):
"""Denormalize name."""
return name
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
"""Get poreignn keys of a table in a schema."""
return []
def get_indexes(self, connection, table_name, schema, **kw):
"""Get indexes of a table in a schema."""
return []
def visit_create_index(self, create):
"""Disable index creation since that's not a thing in Vertica."""
return None
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
"""Get primary keye of a table in a schema."""
return None
def get_all_columns(self, connection, table, schema=None, **kw):
"""Get all columns of a table in a schema."""
return []
def get_columns(self, connection, table_name, schema=None, **kw):
"""Get all columns of a table in a schema."""
return self.get_all_columns(connection, table_name, schema)
def get_view_columns(self, connection, view, schema=None, **kw):
"""Get columns of views in a schema."""
return []
def get_view_comment(self, connection, view, schema=None, **kw):
"""Get comment of view."""
return {"text": view}