From 5093e3714a1dc2563629f67c957328bc0ed59d04 Mon Sep 17 00:00:00 2001 From: "tuyang.yhj" Date: Thu, 6 Jul 2023 10:25:02 +0800 Subject: [PATCH] WEB API independent --- pilot/common/sql_database.py | 26 +++++++++++++++++ .../connections/rdbms/py_study/study_data.py | 7 +++-- pilot/scene/chat_dashboard/chat.py | 29 +++++++++++++++++-- .../data_preparation/report_schma.py | 15 ++++++++-- pilot/scene/chat_dashboard/prompt.py | 12 ++++++-- 5 files changed, 79 insertions(+), 10 deletions(-) diff --git a/pilot/common/sql_database.py b/pilot/common/sql_database.py index 501227873..8b6c5dfbe 100644 --- a/pilot/common/sql_database.py +++ b/pilot/common/sql_database.py @@ -268,6 +268,32 @@ class Database: result.insert(0, field_names) return result + def query_ex(self, session, query, fetch: str = "all"): + """ + only for query + Args: + session: + query: + fetch: + Returns: + """ + print(f"Query[{query}]") + if not query: + return [] + cursor = session.execute(text(query)) + if cursor.returns_rows: + if fetch == "all": + result = cursor.fetchall() + elif fetch == "one": + result = cursor.fetchone()[0] # type: ignore + else: + raise ValueError("Fetch parameter must be either 'one' or 'all'") + field_names = list(i[0:] for i in cursor.keys()) + + result = list(result) + return field_names, result + + def run(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results.""" print("SQL:" + command) diff --git a/pilot/connections/rdbms/py_study/study_data.py b/pilot/connections/rdbms/py_study/study_data.py index 6d6b88e04..c7245d1a7 100644 --- a/pilot/connections/rdbms/py_study/study_data.py +++ b/pilot/connections/rdbms/py_study/study_data.py @@ -10,8 +10,11 @@ if __name__ == "__main__": # print(datas) - str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}""" + # str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}""" + # + # print(str.find("[")) - print(str.find("[")) + test =["t1", "t2", "t3", "tx"] + print(str(test[1:])) diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 366edd56a..615808b0e 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -2,6 +2,8 @@ import json import os import uuid from typing import Dict, NamedTuple, List +from decimal import Decimal + from pilot.scene.base_message import ( HumanMessage, ViewMessage, @@ -17,6 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, ReportData, + ValueItem ) CFG = Config() @@ -77,14 +80,34 @@ class ChatDashboard(BaseChat): chart_datas: List[ChartData] = [] for chart_item in prompt_response: try: - datas = self.database.run(self.db_connect, chart_item.sql) + field_names, datas = self.database.query_ex(self.db_connect, chart_item.sql) + values: List[ValueItem] = [] + data_map = {} + field_map = {} + index = 0 + for field_name in field_names: + data_map.update({f"{field_name}": [row[index] for row in datas]}) + index += 1 + if not data_map[field_name]: + field_map.update({f"{field_name}": False}) + else: + field_map.update({f"{field_name}": all(isinstance(item, (int, float, Decimal)) for item in data_map[field_name])}) + + for field_name in field_names[1:]: + if not field_map[field_name]: + print("more than 2 non-numeric column") + else: + for data in datas: + value_item = ValueItem(name=data[0], type=field_name, value=data[field_names.index(field_name)]) + values.append(value_item) + chart_datas.append(ChartData(chart_uid=str(uuid.uuid1()), chart_name=chart_item.title, chart_type=chart_item.showcase, chart_desc=chart_item.thoughts, chart_sql=chart_item.sql, - column_name=datas[0], - values=datas)) + column_name=field_names, + values=values)) except Exception as e: # TODO 修复流程 print(str(e)) diff --git a/pilot/scene/chat_dashboard/data_preparation/report_schma.py b/pilot/scene/chat_dashboard/data_preparation/report_schma.py index 4b9ca9f58..7c1ff8fb2 100644 --- a/pilot/scene/chat_dashboard/data_preparation/report_schma.py +++ b/pilot/scene/chat_dashboard/data_preparation/report_schma.py @@ -3,6 +3,17 @@ from pydantic import BaseModel, Field from typing import TypeVar, Union, List, Generic, Any from dataclasses import dataclass, asdict +class ValueItem(BaseModel): + name: str + type: str = None + value: float + def dict(self, *args, **kwargs): + return { + "name": self.name, + "type": self.type, + "value": self.value + } + class ChartData(BaseModel): chart_uid: str chart_name: str @@ -10,7 +21,7 @@ class ChartData(BaseModel): chart_desc: str chart_sql: str column_name: List - values: List + values: List[ValueItem] style: Any = None def dict(self, *args, **kwargs): @@ -21,7 +32,7 @@ class ChartData(BaseModel): "chart_desc": self.chart_desc, "chart_sql": self.chart_sql, "column_name": [str(item) for item in self.column_name], - "values": [[str(item) for item in sublist] for sublist in self.values], + "values": [value.dict() for value in self.values], "style": self.style } diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 00e40e790..212d24e96 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -13,8 +13,14 @@ PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provid _DEFAULT_TEMPLATE = """ According to the structure definition in the following tables: {table_info} -Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions. -Used to support goal: {input} +Provide professional data analysis to support the goal: +{input} + +Constraint: +Provide multi-dimensional analysis as much as possible according to the target requirements, no less than three and no more than 8 dimensions. +The data columns of the analysis output should not exceed 4. +According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for display, chart type: +{supported_chat_type} Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, use different types of charts as much as possible,chart types: @@ -22,7 +28,7 @@ According to the characteristics of the analyzed data, choose the best one from Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: {response} -Do not use unprovided fields and do not use unprovided data in the where condition of sql. +Do not use unprovided fields and value in the where condition of sql. Ensure the response is correct json and can be parsed by Python json.loads """