From 8b7197d83a89ff5d73e47d66f8cc329218dcf7eb Mon Sep 17 00:00:00 2001 From: lozzow Date: Sun, 24 Sep 2023 23:56:18 +0000 Subject: [PATCH 1/5] feat: add postgresql support --- .../connections/manages/connection_manager.py | 1 + pilot/connections/rdbms/conn_postgresql.py | 197 ++++++++++++++++++ pilot/scene/chat_dashboard/out_parser.py | 6 +- pilot/scene/chat_dashboard/prompt.py | 22 +- setup.py | 1 + 5 files changed, 215 insertions(+), 12 deletions(-) create mode 100644 pilot/connections/rdbms/conn_postgresql.py diff --git a/pilot/connections/manages/connection_manager.py b/pilot/connections/manages/connection_manager.py index 534cd36f0..070560cde 100644 --- a/pilot/connections/manages/connection_manager.py +++ b/pilot/connections/manages/connection_manager.py @@ -14,6 +14,7 @@ from pilot.connections.rdbms.conn_sqlite import SQLiteConnect from pilot.connections.rdbms.conn_mssql import MSSQLConnect from pilot.connections.rdbms.base import RDBMSDatabase from pilot.connections.rdbms.conn_clickhouse import ClickhouseConnect +from pilot.connections.rdbms.conn_postgresql import PostgreSQLDatabase from pilot.singleton import Singleton from pilot.common.sql_database import Database from pilot.connections.db_conn_info import DBConfig diff --git a/pilot/connections/rdbms/conn_postgresql.py b/pilot/connections/rdbms/conn_postgresql.py new file mode 100644 index 000000000..720f11f6d --- /dev/null +++ b/pilot/connections/rdbms/conn_postgresql.py @@ -0,0 +1,197 @@ +from typing import Iterable, Optional, Any +from sqlalchemy import text +from urllib.parse import quote +from pilot.connections.rdbms.base import RDBMSDatabase + + +class PostgreSQLDatabase(RDBMSDatabase): + driver = 'postgresql+psycopg2' + db_type = "postgresql" + db_dialect = 'postgresql' + + @classmethod + def from_uri_db( + cls, + host: str, + port: int, + user: str, + pwd: str, + db_name: str, + engine_args: Optional[dict] = None, + **kwargs: Any, + ) -> RDBMSDatabase: + db_url: str = ( + cls.driver + + "://" + + quote(user) + + ":" + + quote(pwd) + + "@" + + host + + ":" + + str(port) + + "/" + + db_name + ) + return cls.from_uri(db_url, engine_args, **kwargs) + + def _sync_tables_from_db(self) -> Iterable[str]: + table_results = self.session.execute( + text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + ) + view_results = self.session.execute( + text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + ) + table_results = set(row[0] for row in table_results) + view_results = set(row[0] for row in view_results) + self._all_tables = table_results.union(view_results) + self._metadata.reflect(bind=self._engine) + return self._all_tables + + + def get_grants(self): + session = self._db_sessions() + cursor = session.execute(text(f""" + SELECT DISTINCT grantee, privilege_type + FROM information_schema.role_table_grants + WHERE grantee = CURRENT_USER;""")) + grants = cursor.fetchall() + return grants + + def get_collation(self): + """Get collation.""" + session = self._db_sessions() + cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) + collation = cursor.fetchone()[0] + return collation + + def get_users(self): + """Get user info.""" + try: + cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")) + users = cursor.fetchall() + return [user[0] for user in users] + except Exception as e: + return [] + + def get_fields(self, table_name): + """Get column fields about specified table.""" + session = self._db_sessions() + cursor = session.execute( + text( + f"SELECT column_name, data_type, column_default, is_nullable, column_name as column_comment \ + FROM information_schema.columns WHERE table_name = :table_name", + ), + {"table_name": table_name}, + ) + fields = cursor.fetchall() + return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] + + def get_charset(self): + """Get character_set.""" + session = self._db_sessions() + cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();")) + character_set = cursor.fetchone()[0] + return character_set + + + def get_show_create_table(self,table_name): + cur = self.session.execute( + text( + f""" + SELECT a.attname as column_name, pg_catalog.format_type(a.atttypid, a.atttypmod) as data_type + FROM pg_catalog.pg_attribute a + WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attnum <= ( + SELECT max(a.attnum) + FROM pg_catalog.pg_attribute a + WHERE a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}') + ) AND a.attrelid = (SELECT oid FROM pg_catalog.pg_class WHERE relname='{table_name}') + """ + ) + ) + rows = cur.fetchall() + + create_table_query = f"CREATE TABLE {table_name} (\n" + for row in rows: + create_table_query += f" {row[0]} {row[1]},\n" + create_table_query = create_table_query.rstrip(',\n') + "\n)" + + return create_table_query + + def get_table_comments(self, db_name=None): + tablses = self.table_simple_info() + comments = [] + for table in tablses: + table_name = table[0] + table_comment = self.get_show_create_table(table_name) + comments.append((table_name, table_comment)) + return comments + + def get_database_list(self): + session = self._db_sessions() + cursor = session.execute(text("SELECT datname FROM pg_database;")) + results = cursor.fetchall() + return [ + d[0] + for d in results + if d[0] not in ["template0", "template1", "postgres"] + ] + + def get_database_names(self): + session = self._db_sessions() + cursor = session.execute(text("SELECT datname FROM pg_database;")) + results = cursor.fetchall() + return [ + d[0] + for d in results + if d[0] not in ["template0", "template1", "postgres"] + ] + + def get_current_db_name(self) -> str: + return self.session.execute(text("SELECT current_database()")).scalar() + + def table_simple_info(self): + _sql = f""" + SELECT table_name, string_agg(column_name, ', ') AS schema_info + FROM ( + SELECT c.relname AS table_name, a.attname AS column_name + FROM pg_catalog.pg_class c + JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace + JOIN pg_catalog.pg_attribute a ON a.attrelid = c.oid + WHERE c.relkind = 'r' + AND a.attnum > 0 + AND NOT a.attisdropped + AND n.nspname NOT LIKE 'pg_%' + AND n.nspname != 'information_schema' + ORDER BY c.relname, a.attnum + ) sub + GROUP BY table_name; + """ + cursor = self.session.execute(text(_sql)) + results = cursor.fetchall() + return results + + def get_fields(self, table_name, schema_name='public'): + """Get column fields about specified table.""" + session = self._db_sessions() + cursor = session.execute( + text( + f""" + SELECT c.column_name, c.data_type, c.column_default, c.is_nullable, d.description + FROM information_schema.columns c + LEFT JOIN pg_catalog.pg_description d + ON (c.table_schema || '.' || c.table_name)::regclass::oid = d.objoid AND c.ordinal_position = d.objsubid + WHERE c.table_name='{table_name}' AND c.table_schema='{schema_name}' + """ + ) + ) + fields = cursor.fetchall() + return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] + + + def get_indexes(self, table_name): + """Get table indexes about specified table.""" + session = self._db_sessions() + cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'")) + indexes = cursor.fetchall() + return [(index[0], index[1]) for index in indexes] diff --git a/pilot/scene/chat_dashboard/out_parser.py b/pilot/scene/chat_dashboard/out_parser.py index bf0fedef4..ac04c10e8 100644 --- a/pilot/scene/chat_dashboard/out_parser.py +++ b/pilot/scene/chat_dashboard/out_parser.py @@ -21,9 +21,9 @@ class ChatDashboardOutputParser(BaseOutputParser): super().__init__(sep=sep, is_stream_out=is_stream_out) def parse_prompt_response(self, model_out_text): - clean_str = super().parse_prompt_response(model_out_text) - print("clean prompt response:", clean_str) - response = json.loads(clean_str) + # clean_str = super().parse_prompt_response(model_out_text) + print("clean prompt response:", model_out_text) + response = json.loads(model_out_text) chart_items: List[ChartItem] = [] if not isinstance(response, list): response = [response] diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 5b93ebb1f..7b588e9e3 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -7,24 +7,28 @@ from pilot.common.schema import SeparatorStyle CFG = Config() -PROMPT_SCENE_DEFINE = "You are a data analysis expert, please provide a professional data analysis solution" +PROMPT_SCENE_DEFINE = "你是一个数据分析专家,请提供专业的数据分析解决方案" _DEFAULT_TEMPLATE = """ -According to the following table structure definition: +根据以下表结构定义: {table_info} -Provide professional data analysis to support users' goals: +提供专业的数据分析以支持用户的目标: {input} -Provide at least 4 and at most 8 dimensions of analysis according to user goals. -The output data of the analysis cannot exceed 4 columns, and do not use columns such as pay_status in the SQL where condition for data filtering. -According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type: +根据用户目标,提供至少4个,最多8个维度的分析。 +分析的输出数据不能超过4列,不要在SQL where条件中使用如pay_status之类的列进行数据筛选。 +根据分析数据的特性,从下面提供的图表中选择最合适的一种进行数据展示,图表类型: {supported_chat_type} -Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens +注意分析结果的输出内容长度,不要超过4000个令牌 -Give the correct {dialect} analysis SQL (don't use unprovided values such as 'paid'), analysis title(don't exist the same), display method and summary of brief analysis thinking, and respond in the following json format: +给出正确的{dialect}分析SQL +1.不要使用未提供的值,如'paid' +2.所有查询的值必须是有别名的,如select count(*) as count from table +3.如果表结构定义使用了{dialect}的关键字作为字段名,需要使用转义符,如select `count` from table +4.仔细检查SQL的正确性,SQL必须是正确的,显示方法和简要分析思路的总结,并以以下json格式回应: {response} -Ensure the response is correct json and can be parsed by Python json.loads +做重要的额是:请确保只返回json字符串,不要添加任何其他内容(用于程序直接处理),并且json并能被Python json.loads解析 """ RESPONSE_FORMAT = [ diff --git a/setup.py b/setup.py index 64ae3e8b4..734cbaa7e 100644 --- a/setup.py +++ b/setup.py @@ -295,6 +295,7 @@ def core_requires(): "langchain>=0.0.286", "SQLAlchemy", "pymysql", + "psycopg2" "duckdb", "duckdb-engine", "jsonschema", From 92a914732bed145aa8eaffef86974f4135ad09fe Mon Sep 17 00:00:00 2001 From: lozzow Date: Mon, 25 Sep 2023 12:18:08 +0000 Subject: [PATCH 2/5] fix: translate chinese to english --- pilot/scene/chat_dashboard/prompt.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 7b588e9e3..9fed97f8f 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -7,28 +7,28 @@ from pilot.common.schema import SeparatorStyle CFG = Config() -PROMPT_SCENE_DEFINE = "你是一个数据分析专家,请提供专业的数据分析解决方案" +PROMPT_SCENE_DEFINE = "You are a data analysis expert, please provide a professional data analysis solution" _DEFAULT_TEMPLATE = """ -根据以下表结构定义: +According to the following table structure definition: {table_info} -提供专业的数据分析以支持用户的目标: +Provide professional data analysis to support users' goals: {input} -根据用户目标,提供至少4个,最多8个维度的分析。 -分析的输出数据不能超过4列,不要在SQL where条件中使用如pay_status之类的列进行数据筛选。 -根据分析数据的特性,从下面提供的图表中选择最合适的一种进行数据展示,图表类型: +Provide at least 4 and at most 8 dimensions of analysis according to user goals. +The output data of the analysis cannot exceed 4 columns, and do not use columns such as pay_status in the SQL where condition for data filtering. +According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for data display, chart type: {supported_chat_type} -注意分析结果的输出内容长度,不要超过4000个令牌 +Pay attention to the length of the output content of the analysis result, do not exceed 4000 tokens -给出正确的{dialect}分析SQL -1.不要使用未提供的值,如'paid' -2.所有查询的值必须是有别名的,如select count(*) as count from table -3.如果表结构定义使用了{dialect}的关键字作为字段名,需要使用转义符,如select `count` from table -4.仔细检查SQL的正确性,SQL必须是正确的,显示方法和简要分析思路的总结,并以以下json格式回应: +Give the correct {dialect} analysis SQL +1.Do not use unprovided values such as 'paid' +2.All queried values must have aliases, such as select count(*) as count from table +3.If the table structure definition uses the keywords of {dialect} as field names, you need to use escape characters, such as select `count` from table +4.Carefully check the correctness of the SQL, the SQL must be correct, display method and summary of brief analysis thinking, and respond in the following json format: {response} -做重要的额是:请确保只返回json字符串,不要添加任何其他内容(用于程序直接处理),并且json并能被Python json.loads解析 +The important thing is: Please make sure to only return the json string, do not add any other content (for direct processing by the program), and the json can be parsed by Python json.loads """ RESPONSE_FORMAT = [ From eee2072acf28b3ad5397cb86da04323c629059fc Mon Sep 17 00:00:00 2001 From: lozzow Date: Mon, 25 Sep 2023 12:18:34 +0000 Subject: [PATCH 3/5] build: mv psycopg2 from framework to datasource --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 734cbaa7e..3521c8d8c 100644 --- a/setup.py +++ b/setup.py @@ -295,7 +295,6 @@ def core_requires(): "langchain>=0.0.286", "SQLAlchemy", "pymysql", - "psycopg2" "duckdb", "duckdb-engine", "jsonschema", @@ -364,7 +363,7 @@ def all_datasource_requires(): """ pip install "db-gpt[datasource]" """ - setup_spec.extras["datasource"] = ["pymssql", "pymysql"] + setup_spec.extras["datasource"] = ["pymssql", "pymysql","psycopg2"] def openai_requires(): From 6a03c24ee78b02c5eed220ed2b46a35e9dcdc150 Mon Sep 17 00:00:00 2001 From: lozzow Date: Mon, 25 Sep 2023 13:05:04 +0000 Subject: [PATCH 4/5] fix: add error message log out --- pilot/connections/rdbms/conn_postgresql.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pilot/connections/rdbms/conn_postgresql.py b/pilot/connections/rdbms/conn_postgresql.py index 720f11f6d..aab7e4ccc 100644 --- a/pilot/connections/rdbms/conn_postgresql.py +++ b/pilot/connections/rdbms/conn_postgresql.py @@ -60,10 +60,14 @@ class PostgreSQLDatabase(RDBMSDatabase): def get_collation(self): """Get collation.""" - session = self._db_sessions() - cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) - collation = cursor.fetchone()[0] - return collation + try: + session = self._db_sessions() + cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) + collation = cursor.fetchone()[0] + return collation + except Exception as e: + print("postgresql get collation error: ", e) + return None def get_users(self): """Get user info.""" @@ -72,6 +76,7 @@ class PostgreSQLDatabase(RDBMSDatabase): users = cursor.fetchall() return [user[0] for user in users] except Exception as e: + print("postgresql get users error: ", e) return [] def get_fields(self, table_name): From 14dc5dac04aaa05b4ffe7ad772dc5546de7f9b2b Mon Sep 17 00:00:00 2001 From: lozzow Date: Mon, 25 Sep 2023 14:29:16 +0000 Subject: [PATCH 5/5] style: fix code style with black --- pilot/connections/rdbms/conn_postgresql.py | 75 +++++++++++++--------- setup.py | 2 +- 2 files changed, 46 insertions(+), 31 deletions(-) diff --git a/pilot/connections/rdbms/conn_postgresql.py b/pilot/connections/rdbms/conn_postgresql.py index aab7e4ccc..2eeca4f04 100644 --- a/pilot/connections/rdbms/conn_postgresql.py +++ b/pilot/connections/rdbms/conn_postgresql.py @@ -5,9 +5,9 @@ from pilot.connections.rdbms.base import RDBMSDatabase class PostgreSQLDatabase(RDBMSDatabase): - driver = 'postgresql+psycopg2' + driver = "postgresql+psycopg2" db_type = "postgresql" - db_dialect = 'postgresql' + db_dialect = "postgresql" @classmethod def from_uri_db( @@ -34,13 +34,17 @@ class PostgreSQLDatabase(RDBMSDatabase): + db_name ) return cls.from_uri(db_url, engine_args, **kwargs) - + def _sync_tables_from_db(self) -> Iterable[str]: table_results = self.session.execute( - text("SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + text( + "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + ) ) view_results = self.session.execute( - text("SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'") + text( + "SELECT viewname FROM pg_catalog.pg_views WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema'" + ) ) table_results = set(row[0] for row in table_results) view_results = set(row[0] for row in view_results) @@ -48,31 +52,40 @@ class PostgreSQLDatabase(RDBMSDatabase): self._metadata.reflect(bind=self._engine) return self._all_tables - def get_grants(self): session = self._db_sessions() - cursor = session.execute(text(f""" + cursor = session.execute( + text( + f""" SELECT DISTINCT grantee, privilege_type FROM information_schema.role_table_grants - WHERE grantee = CURRENT_USER;""")) + WHERE grantee = CURRENT_USER;""" + ) + ) grants = cursor.fetchall() return grants - + def get_collation(self): """Get collation.""" try: session = self._db_sessions() - cursor = session.execute(text("SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();")) + cursor = session.execute( + text( + "SELECT datcollate AS collation FROM pg_database WHERE datname = current_database();" + ) + ) collation = cursor.fetchone()[0] return collation except Exception as e: print("postgresql get collation error: ", e) return None - + def get_users(self): """Get user info.""" try: - cursor = self.session.execute(text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';")) + cursor = self.session.execute( + text("SELECT rolname FROM pg_roles WHERE rolname NOT LIKE 'pg_%';") + ) users = cursor.fetchall() return [user[0] for user in users] except Exception as e: @@ -91,16 +104,19 @@ class PostgreSQLDatabase(RDBMSDatabase): ) fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] - + def get_charset(self): """Get character_set.""" session = self._db_sessions() - cursor = session.execute(text("SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();")) + cursor = session.execute( + text( + "SELECT pg_encoding_to_char(encoding) FROM pg_database WHERE datname = current_database();" + ) + ) character_set = cursor.fetchone()[0] return character_set - - - def get_show_create_table(self,table_name): + + def get_show_create_table(self, table_name): cur = self.session.execute( text( f""" @@ -119,10 +135,10 @@ class PostgreSQLDatabase(RDBMSDatabase): create_table_query = f"CREATE TABLE {table_name} (\n" for row in rows: create_table_query += f" {row[0]} {row[1]},\n" - create_table_query = create_table_query.rstrip(',\n') + "\n)" + create_table_query = create_table_query.rstrip(",\n") + "\n)" return create_table_query - + def get_table_comments(self, db_name=None): tablses = self.table_simple_info() comments = [] @@ -131,15 +147,13 @@ class PostgreSQLDatabase(RDBMSDatabase): table_comment = self.get_show_create_table(table_name) comments.append((table_name, table_comment)) return comments - + def get_database_list(self): session = self._db_sessions() cursor = session.execute(text("SELECT datname FROM pg_database;")) results = cursor.fetchall() return [ - d[0] - for d in results - if d[0] not in ["template0", "template1", "postgres"] + d[0] for d in results if d[0] not in ["template0", "template1", "postgres"] ] def get_database_names(self): @@ -147,11 +161,9 @@ class PostgreSQLDatabase(RDBMSDatabase): cursor = session.execute(text("SELECT datname FROM pg_database;")) results = cursor.fetchall() return [ - d[0] - for d in results - if d[0] not in ["template0", "template1", "postgres"] + d[0] for d in results if d[0] not in ["template0", "template1", "postgres"] ] - + def get_current_db_name(self) -> str: return self.session.execute(text("SELECT current_database()")).scalar() @@ -176,7 +188,7 @@ class PostgreSQLDatabase(RDBMSDatabase): results = cursor.fetchall() return results - def get_fields(self, table_name, schema_name='public'): + def get_fields(self, table_name, schema_name="public"): """Get column fields about specified table.""" session = self._db_sessions() cursor = session.execute( @@ -193,10 +205,13 @@ class PostgreSQLDatabase(RDBMSDatabase): fields = cursor.fetchall() return [(field[0], field[1], field[2], field[3], field[4]) for field in fields] - def get_indexes(self, table_name): """Get table indexes about specified table.""" session = self._db_sessions() - cursor = session.execute(text(f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'")) + cursor = session.execute( + text( + f"SELECT indexname, indexdef FROM pg_indexes WHERE tablename = '{table_name}'" + ) + ) indexes = cursor.fetchall() return [(index[0], index[1]) for index in indexes] diff --git a/setup.py b/setup.py index 3521c8d8c..7a3b0b5b8 100644 --- a/setup.py +++ b/setup.py @@ -363,7 +363,7 @@ def all_datasource_requires(): """ pip install "db-gpt[datasource]" """ - setup_spec.extras["datasource"] = ["pymssql", "pymysql","psycopg2"] + setup_spec.extras["datasource"] = ["pymssql", "pymysql", "psycopg2"] def openai_requires():