feat: add postgresql support

This commit is contained in:
lozzow 2023-09-24 23:56:18 +00:00
parent c830598c9e
commit 8b7197d83a
5 changed files with 215 additions and 12 deletions

View File

@ -14,6 +14,7 @@ from pilot.connections.rdbms.conn_sqlite import SQLiteConnect
from pilot.connections.rdbms.conn_mssql import MSSQLConnect
from pilot.connections.rdbms.base import RDBMSDatabase
from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect
from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase
from pilot.singleton import Singleton
from pilot.common.sql_database import Database
from pilot.connections.db_conn_info import DBConfig

View File

@ -0,0 +1,197 @@
from typing import Iterable, Optional, Any
from sqlalchemy import text
from urllib.parse import quote
from pilot.connections.rdbms.base import RDBMSDatabase
class PostgreSQLDatabase(RDBMSDatabase):
driver = 'postgresql+psycopg2'
db_type = "postgresql"
db_dialect = 'postgresql'
@classmethod
def from_uri_db(
cls,
host: str,
port: int,
user: str,
pwd: str,
db_name: str,
engine_args: Optional[dict] = None,
**kwargs: Any,
) -> RDBMSDatabase:
db_url: str = (
cls.driver
+ "://"
+ quote(user)
+ ":"
+ quote(pwd)
+ "@"
+ host
+ ":"
+ str(port)
+ "/"
+ db_name
)
return cls.from_uri(db_url, engine_args, **kwargs)
def _sync_tables_from_db(self) -> Iterable[str]:
table_results = self.session.execute(
text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
)
view_results = self.session.execute(
text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'")
)
table_results = set(row[0] for row in table_results)
view_results = set(row[0] for row in view_results)
self._all_tables = table_results.union(view_results)
self._metadata.reflect(bind=self._engine)
return self._all_tables
def get_grants(self):
session = self._db_sessions()
cursor = session.execute(text(f"""
SELECT DISTINCT grantee, privilege_type
FROM information_schema.role_table_grants
WHERE grantee = CURRENT_USER;"""))
grants = cursor.fetchall()
return grants
def get_collation(self):
"""Get collation."""
session = self._db_sessions()
cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();"))
collation = cursor.fetchone()[0]
return collation
def get_users(self):
"""Get user info."""
try:
cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';"))
users = cursor.fetchall()
return [user[0] for user in users]
except Exception as e:
return []
def get_fields(self, table_name):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \
FROM information_schema.columns WHERE table_name = :table_name",
),
{"table_name": table_name},
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_charset(self):
"""Get character_set."""
session = self._db_sessions()
cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();"))
character_set = cursor.fetchone()[0]
return character_set
def get_show_create_table(self,table_name):
cur = self.session.execute(
text(
f"""
SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type
FROM pg_catalog.pg_attribute a
WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= (
SELECT max(a.attnum)
FROM pg_catalog.pg_attribute a
WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}')
"""
)
)
rows = cur.fetchall()
create_table_query = f"CREATE TABLE {table_name} (\n"
for row in rows:
create_table_query += f" {row[0]} {row[1]},\n"
create_table_query = create_table_query.rstrip(',\n') + "\n)"
return create_table_query
def get_table_comments(self, db_name=None):
tablses = self.table_simple_info()
comments = []
for table in tablses:
table_name = table[0]
table_comment = self.get_show_create_table(table_name)
comments.append((table_name, table_comment))
return comments
def get_database_list(self):
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
]
def get_database_names(self):
session = self._db_sessions()
cursor = session.execute(text("SELECT datname FROM pg_database;"))
results = cursor.fetchall()
return [
d[0]
for d in results
if d[0] not in ["template0", "template1", "postgres"]
]
def get_current_db_name(self) -> str:
return self.session.execute(text("SELECT current_database()")).scalar()
def table_simple_info(self):
_sql = f"""
SELECT table_name, string_agg(column_name, ', ') AS schema_info
FROM (
SELECT c.relname AS table_name, a.attname AS column_name
FROM pg_catalog.pg_class c
JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid
WHERE c.relkind = 'r'
AND a.attnum > 0
AND NOT a.attisdropped
AND n.nspname NOT LIKE 'pg_%'
AND n.nspname != 'information_schema'
ORDER BY c.relname, a.attnum
) sub
GROUP BY table_name;
"""
cursor = self.session.execute(text(_sql))
results = cursor.fetchall()
return results
def get_fields(self, table_name, schema_name='public'):
"""Get column fields about specified table."""
session = self._db_sessions()
cursor = session.execute(
text(
f"""
SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description
FROM information_schema.columns c
LEFT JOIN pg_catalog.pg_description d
ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid
WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}'
"""
)
)
fields = cursor.fetchall()
return [(field[0], field[1], field[2], field[3], field[4]) for field in fields]
def get_indexes(self, table_name):
"""Get table indexes about specified table."""
session = self._db_sessions()
cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'"))
indexes = cursor.fetchall()
return [(index[0], index[1]) for index in indexes]

View File

@ -21,9 +21,9 @@ class ChatDashboardOutputParser(BaseOutputParser):
super().__init__(sep=sep, is_stream_out=is_stream_out)
def parse_prompt_response(self, model_out_text):
clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", clean_str)
response = json.loads(clean_str)
# clean_str = super().parse_prompt_response(model_out_text)
print("clean prompt response:", model_out_text)
response = json.loads(model_out_text)
chart_items: List[ChartItem] = []
if not isinstance(response, list):
response = [response]

View File

@ -7,24 +7,28 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert, please provide a professional data analysis solution"
PROMPT_SCENE_DEFINE = "你是一个数据分析专家,请提供专业的数据分析解决方案"
_DEFAULT_TEMPLATE = """
According to the following table structure definition:
根据以下表结构定义
{table_info}
Provide professional data analysis to support users' goals:
提供专业的数据分析以支持用户的目标
{input}
Provide at least 4 and at most 8 dimensions of analysis according to user goals.
The output data of the analysis cannot exceed 4 columns, and do not use columns such as pay_status in the SQL where condition for data filtering.
According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type:
根据用户目标提供至少4个最多8个维度的分析
分析的输出数据不能超过4列不要在SQL where条件中使用如pay_status之类的列进行数据筛选
根据分析数据的特性从下面提供的图表中选择最合适的一种进行数据展示图表类型
{supported_chat_type}
Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens
注意分析结果的输出内容长度不要超过4000个令牌
Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title(don't exist the same), display method and summary of brief analysis thinking, and respond in the following json format:
给出正确的{dialect}分析SQL
1.不要使用未提供的值'paid'
2.所有查询的值必须是有别名的如select count(*) as count from table
3.如果表结构定义使用了{dialect}的关键字作为字段名需要使用转义符如select `count` from table
4.仔细检查SQL的正确性SQL必须是正确的显示方法和简要分析思路的总结并以以下json格式回应
{response}
Ensure the response is correct json and can be parsed by Python json.loads
做重要的额是:请确保只返回json字符串不要添加任何其他内容(用于程序直接处理),并且json并能被Python json.loads解析
"""
RESPONSE_FORMAT = [

View File

@ -295,6 +295,7 @@ def core_requires():
"langchain>=0.0.286",
"SQLAlchemy",
"pymysql",
"psycopg2"
"duckdb",
"duckdb-engine",
"jsonschema",