DB-GPT/dbgpt/datasource/rdbms/conn_vertica.py
2024-08-05 19:26:39 +08:00

257 lines
8.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Vertica connector."""
import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
from urllib.parse import quote
from urllib.parse import quote_plus as urlquote
from sqlalchemy import text
from sqlalchemy.dialects import registry
from .base import RDBMSConnector
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Register the project's Vertica dialect with SQLAlchemy's dialect registry.
# Arguments are: the registry key ("<dialect>.<driver>"), the import path of the
# module that holds the dialect, and the name of the dialect class inside it.
# This runs at import time so the dialect is known before any engine is created.
registry.register(
"vertica.vertica_python",
"dbgpt.datasource.rdbms.dialect.vertica.dialect_vertica_python",
"VerticaDialect",
)
class VerticaConnector(RDBMSConnector):
    """Vertica connector.

    Per-table lookups in this class expect schema-qualified names
    (``schema.table``): the catalog queries compare the caller's value against
    ``table_schema||'.'||table_name``.

    All caller-supplied identifiers are passed to the database as bound
    parameters rather than being interpolated into the SQL text.
    """

    driver = "vertica+vertica_python"
    db_type = "vertica"
    db_dialect = "vertica"

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> "VerticaConnector":
        """Create a new VerticaConnector from host, port, user, pwd, db_name.

        Args:
            host (str): database host
            port (int): database port
            user (str): user name, percent-encoded before use in the URL
            pwd (str): password, percent-encoded before use in the URL
            db_name (str): database name placed in the URL path
            engine_args (Optional[dict]): forwarded unchanged to ``from_uri``
            **kwargs: forwarded unchanged to ``from_uri``

        Returns:
            VerticaConnector: whatever ``cls.from_uri`` builds for the URL
        """
        db_url: str = (
            f"{cls.driver}://{quote(user)}:{urlquote(pwd)}@{host}:{str(port)}/{db_name}"
        )
        return cast(VerticaConnector, cls.from_uri(db_url, engine_args, **kwargs))

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        # Inject instruction to prompt according to {dialect} in prompt template.
        # The Chinese sentence says: "pay special attention, the table name must
        # be prefixed with the schema name!!". The text is emitted at runtime, so
        # it is kept verbatim; adjacent literals are used so the value does not
        # depend on source indentation.
        return (
            "Vertica sql, "
            "correct postgresql sql is the another option "
            "if you don't know much about Vertica. "
            "尤其要注意,表名称前面一定要带上模式名称!! "
            "Note the most important requirement is that "
            "table name should keep its schema name in "
        )

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Refresh the cached set of schema-qualified table and view names.

        Reads ``v_catalog.tables`` and ``v_catalog.views``, skipping schemas
        that match ``LIKE 'v\\_%'``, stores the names in ``self._all_tables``,
        reflects ``self._metadata`` against the engine and returns the set.
        """
        # Raw string: the backslash must reach the database unchanged, and
        # "\_" is not a valid Python escape sequence in a normal literal.
        table_results = self.session.execute(
            text(
                r"""
                SELECT table_schema||'.'||table_name
                FROM v_catalog.tables
                WHERE table_schema NOT LIKE 'v\_%'
                UNION
                SELECT table_schema||'.'||table_name
                FROM v_catalog.views
                WHERE table_schema NOT LIKE 'v\_%';
                """
            )
        )
        self._all_tables = {row[0] for row in table_results}
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def get_grants(self):
        """Get grants. Always returns an empty list."""
        return []

    def get_collation(self):
        """Get collation. Always returns None."""
        return None

    def get_users(self):
        """Get user info.

        Returns:
            List of user names read from ``v_internal.vs_users``; an empty list
            if the query fails for any reason (the error is logged as a warning).
        """
        try:
            cursor = self.session.execute(text("SELECT name FROM v_internal.vs_users;"))
            return [user[0] for user in cursor.fetchall()]
        except Exception as e:
            # Deliberately best-effort: failing to list users must not break callers.
            logger.warning("vertica get users error: %s", e)
            return []

    def get_fields(self, table_name, db_name=None) -> List[Tuple]:
        """Get column fields about specified table.

        Args:
            table_name: schema-qualified table name (``schema.table``)
            db_name: unused by this implementation

        Returns:
            List[Tuple]: one tuple per column, in the order (column_name,
            data_type, column_default, is_nullable, column_comment). The comment
            falls back to the column name when none is stored.
        """
        session = self._db_sessions()
        cursor = session.execute(
            text(
                """
                SELECT column_name, data_type, column_default, is_nullable,
                nvl(comment, column_name) as column_comment
                FROM v_catalog.columns c
                LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
                AND c.column_name = s.childobject
                WHERE table_schema||'.'||table_name = :qualified_name;
                """
            ),
            {"qualified_name": table_name},
        )
        return [tuple(row) for row in cursor.fetchall()]

    def get_columns(self, table_name: str) -> List[Dict]:
        """Get columns about specified table.

        Args:
            table_name (str): schema-qualified table name (``schema.table``)

        Returns:
            columns: List[Dict], which contains name: str, type: str,
                default_expression: str, is_in_primary_key: bool, comment: str
                eg:[{'name': 'id', 'type': 'int', 'default_expression': '',
                'is_in_primary_key': True, 'comment': 'id'}, ...]
        """
        session = self._db_sessions()
        cursor = session.execute(
            text(
                """
                SELECT c.column_name, data_type, column_default
                , (p.column_name IS NOT NULL) is_in_primary_key
                , nvl(comment, c.column_name) as column_comment
                FROM v_catalog.columns c
                LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
                AND c.column_name = s.childobject
                LEFT JOIN v_catalog.primary_keys p ON c.table_schema = p.table_schema
                AND c.table_name = p.table_name
                AND c.column_name = p.column_name
                WHERE c.table_schema||'.'||c.table_name = :qualified_name;
                """
            ),
            {"qualified_name": table_name},
        )
        return [
            {
                "name": name,
                "type": data_type,
                "default_expression": default_expression,
                "is_in_primary_key": is_in_primary_key,
                "comment": comment,
            }
            for name, data_type, default_expression, is_in_primary_key, comment
            in cursor.fetchall()
        ]

    def get_charset(self):
        """Get character_set. Always returns ``"utf-8"``."""
        return "utf-8"

    def get_show_create_table(self, table_name: str):
        """Return show create table.

        The statement is synthesized from the column names and data types found
        in ``v_catalog.columns``; it is not DDL exported by the database.

        Args:
            table_name (str): schema-qualified table name (``schema.table``)

        Returns:
            str: ``CREATE TABLE <table_name> (...)`` with one `` name type`` line
            per column, or an empty column list when no column is found.
        """
        cur = self.session.execute(
            text(
                """
                SELECT column_name, data_type
                FROM v_catalog.columns
                WHERE table_schema||'.'||table_name = :qualified_name;
                """
            ),
            {"qualified_name": table_name},
        )
        column_defs = ",\n".join(f" {row[0]} {row[1]}" for row in cur.fetchall())
        if not column_defs:
            # No columns found: keep the historical "(\n)" shape.
            return f"CREATE TABLE {table_name} (\n)"
        return f"CREATE TABLE {table_name} (\n{column_defs}\n)"

    def get_table_comments(self, db_name=None):
        """Return table comments.

        Args:
            db_name: schema name to list tables for. When None, it is bound as
                NULL and the query matches no schema.

        Returns:
            List of (schema-qualified table name, comment) tuples. The comment
            falls back to the table name when none is stored.
        """
        cursor = self.session.execute(
            text(
                """
                SELECT table_schema||'.'||table_name
                , nvl(comment, table_name) as column_comment
                FROM v_catalog.tables t
                LEFT JOIN v_internal.vs_comments c ON t.table_id = c.objectoid
                WHERE table_schema = :schema_name
                """
            ),
            {"schema_name": db_name},
        )
        return [(row[0], row[1]) for row in cursor.fetchall()]

    def get_table_comment(self, table_name: str) -> Dict:
        """Get table comments.

        Args:
            table_name (str): schema-qualified table name (``schema.table``)

        Returns:
            comment: Dict, which contains text: Optional[str],
                eg: {"text": "comment"}. The text is None when the table is not
                found, and the table name when no comment is stored.
        """
        cursor = self.session.execute(
            text(
                """
                SELECT nvl(comment, table_name) as column_comment
                FROM v_catalog.tables t
                LEFT JOIN v_internal.vs_comments c ON t.table_id = c.objectoid
                WHERE table_schema||'.'||table_name = :qualified_name
                """
            ),
            {"qualified_name": table_name},
        )
        return {"text": cursor.scalar()}

    def get_column_comments(self, db_name: str, table_name: str):
        """Return column comments.

        Unlike the other per-table lookups, the schema and the table are passed
        separately here and ``table_name`` is NOT schema-qualified.

        Args:
            db_name (str): schema name
            table_name (str): bare table name

        Returns:
            List of (column_name, comment) tuples. The comment falls back to
            the column name when none is stored.
        """
        cursor = self.session.execute(
            text(
                """
                SELECT column_name, nvl(comment, column_name) as column_comment
                FROM v_catalog.columns c
                LEFT JOIN v_internal.vs_sub_comments s ON c.table_id = s.objectoid
                AND c.column_name = s.childobject
                WHERE table_schema = :schema_name AND table_name = :bare_name
                """
            ),
            {"schema_name": db_name, "bare_name": table_name},
        )
        return [(row[0], row[1]) for row in cursor.fetchall()]

    def get_database_names(self):
        """Get database names.

        Returns:
            List of schema names from ``v_catalog.schemata``, excluding those
            whose name starts with ``v_``.
        """
        session = self._db_sessions()
        cursor = session.execute(text("SELECT schema_name FROM v_catalog.schemata;"))
        return [row[0] for row in cursor.fetchall() if not row[0].startswith("v_")]

    def get_current_db_name(self) -> str:
        """Get current database name (the result of ``current_schema()``)."""
        return self.session.execute(text("SELECT current_schema()")).scalar()

    def table_simple_info(self):
        """Get table simple info.

        Returns:
            Rows of (schema-qualified table name, aggregated column names) for
            every table outside schemas matching ``LIKE 'v\\_%'``.
        """
        # Raw string: the backslash must reach the database unchanged, and
        # "\_" is not a valid Python escape sequence in a normal literal.
        _sql = r"""
            SELECT table_schema||'.'||table_name
            , listagg(column_name using parameters max_length=65000)
            FROM v_catalog.columns
            WHERE table_schema NOT LIKE 'v\_%'
            GROUP BY 1;
            """
        cursor = self.session.execute(text(_sql))
        return cursor.fetchall()

    def get_indexes(self, table_name):
        """Get table indexes about specified table. Always returns an empty list."""
        return []