mirror of https://github.com/csunny/DB-GPT.git

commit 8b7197d83a (parent c830598c9e)
feat: add postgresql support
@@ -14,6 +14,7 @@ from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
 from pilot.connections.rdbms.conn_mssql import MSSQLConnect
 from pilot.connections.rdbms.base import RDBMSDatabase
 from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
+from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase
 from pilot.singleton import Singleton
 from pilot.common.sql_database import Database
 from pilot.connections.db_conn_info import DBConfig
pilot/connections/rdbms/conn_postgresql.py (new file, 197 lines)
@@ -0,0 +1,197 @@
from typing import Iterable, Optional, Any

from sqlalchemy import text
from urllib.parse import quote

from pilot.connections.rdbms.base import RDBMSDatabase


class PostgreSQLDatabase(RDBMSDatabase):
    """RDBMSDatabase implementation for PostgreSQL."""

    driver = "postgresql+psycopg2"
    db_type = "postgresql"
    db_dialect = "postgresql"

    @classmethod
    def from_uri_db(
        cls,
        host: str,
        port: int,
        user: str,
        pwd: str,
        db_name: str,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> RDBMSDatabase:
        """Build a SQLAlchemy URL from individual parameters and delegate to from_uri."""
        db_url: str = (
            cls.driver
            + "://"
            + quote(user)
            + ":"
            + quote(pwd)
            + "@"
            + host
            + ":"
            + str(port)
            + "/"
            + db_name
        )
        return cls.from_uri(db_url, engine_args, **kwargs)

    def _sync_tables_from_db(self) -> Iterable[str]:
        """Collect user tables and views, then reflect metadata for them."""
        table_results = self.session.execute(
            text(
                "SELECT tablename FROM pg_catalog.pg_tables "
                "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
            )
        )
        view_results = self.session.execute(
            text(
                "SELECT viewname FROM pg_catalog.pg_views "
                "WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'"
            )
        )
        table_results = set(row[0] for row in table_results)
        view_results = set(row[0] for row in view_results)
        self._all_tables = table_results.union(view_results)
        self._metadata.reflect(bind=self._engine)
        return self._all_tables

    def get_grants(self):
        """Get the table privileges granted to the current user."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                """
                SELECT DISTINCT grantee, privilege_type
                FROM information_schema.role_table_grants
                WHERE grantee = CURRENT_USER;"""
            )
        )
        grants = cursor.fetchall()
        return grants

    def get_collation(self):
        """Get collation of the current database."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT datcollate AS collation FROM pg_database "
                "WHERE datname = current_database();"
            )
        )
        collation = cursor.fetchone()[0]
        return collation

    def get_users(self):
        """Get user info, skipping built-in pg_* roles."""
        try:
            cursor = self.session.execute(
                text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")
            )
            users = cursor.fetchall()
            return [user[0] for user in users]
        except Exception:
            return []

    def get_fields(self, table_name):
        """Get column fields about specified table."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT column_name, data_type, column_default, is_nullable, "
                "column_name as column_comment "
                "FROM information_schema.columns WHERE table_name = :table_name"
            ),
            {"table_name": table_name},
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_charset(self):
        """Get character_set of the current database."""
        session = self._db_sessions()
        cursor = session.execute(
            text(
                "SELECT pg_encoding_to_char(encoding) FROM pg_database "
                "WHERE datname = current_database();"
            )
        )
        character_set = cursor.fetchone()[0]
        return character_set

    def get_show_create_table(self, table_name):
        """Approximate a CREATE TABLE statement from pg_catalog column definitions."""
        cur = self.session.execute(
            text(
                f"""
                SELECT a.attname as column_name,
                       pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
                FROM pg_catalog.pg_attribute a
                WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
                    SELECT max(a.attnum)
                    FROM pg_catalog.pg_attribute a
                    WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
                ) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
                """
            )
        )
        rows = cur.fetchall()

        create_table_query = f"CREATE TABLE {table_name} (\n"
        for row in rows:
            create_table_query += f"    {row[0]} {row[1]},\n"
        create_table_query = create_table_query.rstrip(",\n") + "\n)"

        return create_table_query

    def get_table_comments(self, db_name=None):
        """Return (table_name, reconstructed DDL) pairs for all user tables.

        db_name is accepted for interface compatibility but not used here.
        """
        tables = self.table_simple_info()
        comments = []
        for table in tables:
            table_name = table[0]
            table_comment = self.get_show_create_table(table_name)
            comments.append((table_name, table_comment))
        return comments

    def get_database_list(self):
        """List databases, excluding templates and the default postgres database."""
        session = self._db_sessions()
        cursor = session.execute(text("SELECT datname FROM pg_database;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0] not in ["template0", "template1", "postgres"]
        ]

    def get_database_names(self):
        """Duplicate of get_database_list; returns the same filtered list."""
        session = self._db_sessions()
        cursor = session.execute(text("SELECT datname FROM pg_database;"))
        results = cursor.fetchall()
        return [
            d[0]
            for d in results
            if d[0] not in ["template0", "template1", "postgres"]
        ]

    def get_current_db_name(self) -> str:
        """Return the name of the currently connected database."""
        return self.session.execute(text("SELECT current_database()")).scalar()

    def table_simple_info(self):
        """Return (table_name, comma-separated column list) rows for user tables."""
        _sql = """
            SELECT table_name, string_agg(column_name, ', ') AS schema_info
            FROM (
                SELECT c.relname AS table_name, a.attname AS column_name
                FROM pg_catalog.pg_class c
                JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
                JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
                WHERE c.relkind = 'r'
                  AND a.attnum > 0
                  AND NOT a.attisdropped
                  AND n.nspname NOT LIKE 'pg_%'
                  AND n.nspname != 'information_schema'
                ORDER BY c.relname, a.attnum
            ) sub
            GROUP BY table_name;
        """
        cursor = self.session.execute(text(_sql))
        results = cursor.fetchall()
        return results

    def get_fields(self, table_name, schema_name="public"):  # noqa: F811
        """Get column fields about specified table, including column comments.

        Note: this definition overrides the earlier get_fields above; it joins
        pg_description to pick up column comments and restricts to one schema.
        """
        session = self._db_sessions()
        cursor = session.execute(
            text(
                f"""
                SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
                FROM information_schema.columns c
                LEFT JOIN pg_catalog.pg_description d
                    ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid
                   AND c.ordinal_position = d.objsubid
                WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
                """
            )
        )
        fields = cursor.fetchall()
        return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]

    def get_indexes(self, table_name):
        """Get table indexes about specified table."""
        session = self._db_sessions()
        cursor = session.execute(
            text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'")
        )
        indexes = cursor.fetchall()
        return [(index[0], index[1]) for index in indexes]
@@ -21,9 +21,9 @@ class ChatDashboardOutputParser(BaseOutputParser):
         super().__init__(sep=sep, is_stream_out=is_stream_out)

     def parse_prompt_response(self, model_out_text):
-        clean_str = super().parse_prompt_response(model_out_text)
-        print("clean prompt response:", clean_str)
-        response = json.loads(clean_str)
+        # clean_str = super().parse_prompt_response(model_out_text)
+        print("clean prompt response:", model_out_text)
+        response = json.loads(model_out_text)
         chart_items: List[ChartItem] = []
         if not isinstance(response, list):
             response = [response]
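For context, a small illustrative sketch (not from the commit) of what the revised parser now expects: since the super().parse_prompt_response cleaning step is skipped, model_out_text must already be bare, valid JSON. The sample payload and its keys are hypothetical.

    import json

    # Hypothetical raw model output; it is passed to json.loads unmodified,
    # so it must contain no markdown fences or surrounding prose.
    model_out_text = '[{"title": "orders_by_month", "showcase": "LineChart", "sql": "SELECT 1"}]'
    response = json.loads(model_out_text)
    if not isinstance(response, list):
        response = [response]
    print(response[0]["title"])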
@@ -7,24 +7,28 @@ from pilot.common.schema import SeparatorStyle

 CFG = Config()

-PROMPT_SCENE_DEFINE = "You are a data analysis expert, please provide a professional data analysis solution"
+PROMPT_SCENE_DEFINE = "你是一个数据分析专家,请提供专业的数据分析解决方案"

 _DEFAULT_TEMPLATE = """
-According to the following table structure definition:
+根据以下表结构定义:
 {table_info}
-Provide professional data analysis to support users' goals:
+提供专业的数据分析以支持用户的目标:
 {input}

-Provide at least 4 and at most 8 dimensions of analysis according to user goals.
-The output data of the analysis cannot exceed 4 columns, and do not use columns such as pay_status in the SQL where condition for data filtering.
-According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type:
+根据用户目标,提供至少4个,最多8个维度的分析。
+分析的输出数据不能超过4列,不要在SQL where条件中使用如pay_status之类的列进行数据筛选。
+根据分析数据的特性,从下面提供的图表中选择最合适的一种进行数据展示,图表类型:
 {supported_chat_type}

-Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens
+注意分析结果的输出内容长度,不要超过4000个令牌

-Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title(don't exist the same), display method and summary of brief analysis thinking, and respond in the following json format:
+给出正确的{dialect}分析SQL
+1.不要使用未提供的值,如'paid'
+2.所有查询的值必须是有别名的,如select count(*) as count from table
+3.如果表结构定义使用了{dialect}的关键字作为字段名,需要使用转义符,如select `count` from table
+4.仔细检查SQL的正确性,SQL必须是正确的,显示方法和简要分析思路的总结,并以以下json格式回应:
 {response}
-Ensure the response is correct json and can be parsed by Python json.loads
+最重要的是:请确保只返回json字符串,不要添加任何其他内容(用于程序直接处理),并且json能被Python json.loads解析
 """

 RESPONSE_FORMAT = [