def query_ex(self, session, query, fetch: str = "all"):
    """Run a read-only SQL query and return its column names and rows.

    Args:
        session: Object exposing ``execute``; it is called with
            ``text(query)``.
        query: SQL text to run. A falsy value (``None`` or ``""``) returns
            the empty result without touching ``session``.
        fetch: ``"all"`` to return every row, ``"one"`` to return at most
            the first row.

    Returns:
        A ``(field_names, rows)`` tuple. ``field_names`` is the list of
        column keys reported by the cursor and ``rows`` is a list of result
        rows. Both lists are empty when ``query`` is falsy or when the
        statement does not return rows, so callers can always unpack the
        result into two names.

    Raises:
        ValueError: If ``fetch`` is neither ``"one"`` nor ``"all"``.
    """
    print(f"Query[{query}]")
    if not query:
        # Same two-element shape as the normal path, so that
        # ``field_names, rows = query_ex(...)`` never fails to unpack.
        return [], []
    if fetch not in ("all", "one"):
        # Reject a bad argument before anything is sent to the database.
        raise ValueError("Fetch parameter must be either 'one' or 'all'")

    cursor = session.execute(text(query))
    if not cursor.returns_rows:
        # Nothing to fetch (e.g. a non-SELECT statement).
        return [], []

    field_names = list(cursor.keys())
    if fetch == "all":
        rows = list(cursor.fetchall())
    else:
        # Keep the whole row (not just its first column) so that ``rows``
        # lines up with ``field_names``; an empty result yields no rows.
        first_row = cursor.fetchone()
        rows = [] if first_row is None else [first_row]
    return field_names, rows
pd.read_sql(query, engine.connect()) -# df.style.set_properties(subset=['name'], **{'font-weight': 'bold'}) -# # 导出为HTML文件 -# with open('report.html', 'w') as f: -# f.write(df.style.render()) -# -# # # 设置中文字体 -# # font = FontProperties(fname='SimHei.ttf', size=14) -# # -# # colors = np.random.rand(df.shape[0]) -# # df.plot.scatter(x='city', y='user_name', c=colors) -# # plt.show() -# -# # 查看DataFrame -# print(df.head()) -# -# -# # 创建数据 -# x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] -# y_data = [820, 932, 901, 934, 1290, 1330, 1320] -# -# # 生成图表 -# bar = ( -# Bar() -# .add_xaxis(x_data) -# .add_yaxis("销售额", y_data) -# .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计")) -# ) -# -# # 生成HTML文件 -# bar.render('report.html') -# -# - - -# if __name__ == "__main__": - -# def __extract_json(s): -# i = s.index("{") -# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数 -# for j, c in enumerate(s[i + 1 :], start=i + 1): -# if c == "}": -# count -= 1 -# elif c == "{": -# count += 1 -# if count == 0: -# break -# assert count == 0 # 检查是否找到最后一个'}' -# return s[i : j + 1] -# -# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. 
we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}""" -# print(__extract_json(ss)) - -if __name__ == "__main__": - # test1 = Test1() - # test2 = Test2() - # test1.write() - # test1.test() - # test2.write() - # test1.test() - # test2.test() - - # 定义包含元组的列表 - data = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")] - - # 使用字典解析将列表转换为字典 - result = {k: v for k, v in data} - - print(result) diff --git a/pilot/connections/rdbms/py_study/study_data.py b/pilot/connections/rdbms/py_study/study_data.py deleted file mode 100644 index a5b4f9937..000000000 --- a/pilot/connections/rdbms/py_study/study_data.py +++ /dev/null @@ -1,15 +0,0 @@ -import json -from pilot.common.sql_database import Database -from pilot.configs.config import Config - -CFG = Config() - -if __name__ == "__main__": - # connect = CFG.local_db.get_session("gpt-user") - # datas = CFG.local_db.run(connect, "SELECT * FROM users; ") - - # print(datas) - - str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}""" - - print(str.find("[")) diff --git a/pilot/connections/rdbms/py_study/study_duckdb.py 
b/pilot/connections/rdbms/py_study/study_duckdb.py deleted file mode 100644 index 20e75f38c..000000000 --- a/pilot/connections/rdbms/py_study/study_duckdb.py +++ /dev/null @@ -1,16 +0,0 @@ -import json -import os -import duckdb - -default_db_path = os.path.join(os.getcwd(), "message") -duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db") - -if __name__ == "__main__": - if os.path.isfile("../../../message/chat_history.db"): - cursor = duckdb.connect("../../../message/chat_history.db").cursor() - # cursor.execute("SELECT * FROM chat_history limit 20") - cursor.execute( - "SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'" - ) - data = cursor.fetchall() - print(data) diff --git a/pilot/connections/rdbms/py_study/study_enum.py b/pilot/connections/rdbms/py_study/study_enum.py deleted file mode 100644 index 420d90c4e..000000000 --- a/pilot/connections/rdbms/py_study/study_enum.py +++ /dev/null @@ -1,92 +0,0 @@ -from enum import Enum -from typing import List - - -class Test(Enum): - XXX = ("x", "1", True) - YYY = ("Y", "2", False) - ZZZ = ("Z", "3") - - def __init__(self, code, v, flag=False): - self.code = code - self.v = v - self.flag = flag - - -class Scene: - def __init__( - self, code, name, describe, param_types: List = [], is_inner: bool = False - ): - self.code = code - self.name = name - self.describe = describe - self.param_types = param_types - self.is_inner = is_inner - - -class ChatScene(Enum): - ChatWithDbExecute = Scene( - "chat_with_db_execute", - "Chat Data", - "Dialogue with your private data through natural language.", - ["DB Select"], - ) - ChatWithDbQA = Scene( - "chat_with_db_qa", - "Chat Meta Data", - "Have a Professional Conversation with Metadata.", - ["DB Select"], - ) - ChatExecution = Scene( - "chat_execution", - "Chat Plugin", - "Use tools through dialogue to accomplish your goals.", - ["Plugin Select"], - ) - ChatDefaultKnowledge = Scene( - "chat_default_knowledge", - "Chat Default 
Knowledge", - "Dialogue through natural language and private documents and knowledge bases.", - ) - ChatNewKnowledge = Scene( - "chat_new_knowledge", - "Chat New Knowledge", - "Dialogue through natural language and private documents and knowledge bases.", - ["Knowledge Select"], - ) - ChatUrlKnowledge = Scene( - "chat_url_knowledge", - "Chat URL", - "Dialogue through natural language and private documents and knowledge bases.", - ["Url Input"], - ) - InnerChatDBSummary = Scene( - "inner_chat_db_summary", "DB Summary", "Db Summary.", True - ) - - ChatNormal = Scene( - "chat_normal", "Chat Normal", "Native LLM large model AI dialogue." - ) - ChatDashboard = Scene( - "chat_dashboard", - "Chat Dashboard", - "Provide you with professional analysis reports through natural language.", - ["DB Select"], - ) - ChatKnowledge = Scene( - "chat_knowledge", - "Chat Knowledge", - "Dialogue through natural language and private documents and knowledge bases.", - ["Knowledge Space Select"], - ) - - def scene_value(self): - return self.value.code - - def scene_name(self): - return self._value_.name - - -if __name__ == "__main__": - print(ChatScene.ChatWithDbExecute.scene_value()) - # print(ChatScene.ChatWithDbExecute.value.describe) diff --git a/pilot/connections/rdbms/py_study/test_cls_1.py b/pilot/connections/rdbms/py_study/test_cls_1.py deleted file mode 100644 index 1b91b5601..000000000 --- a/pilot/connections/rdbms/py_study/test_cls_1.py +++ /dev/null @@ -1,12 +0,0 @@ -from abc import ABC, abstractmethod -from pydantic import BaseModel -from test_cls_base import TestBase - - -class Test1(TestBase): - mode: str = "456" - - def write(self): - self.test_values.append("x") - self.test_values.append("y") - self.test_values.append("g") diff --git a/pilot/connections/rdbms/py_study/test_cls_2.py b/pilot/connections/rdbms/py_study/test_cls_2.py deleted file mode 100644 index 1fb4d5e88..000000000 --- a/pilot/connections/rdbms/py_study/test_cls_2.py +++ /dev/null @@ -1,17 +0,0 @@ -from abc 
import ABC, abstractmethod -from pydantic import BaseModel -from test_cls_base import TestBase -from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union - - -class Test2(TestBase): - test_2_values: List = [] - mode: str = "789" - - def write(self): - self.test_values.append(1) - self.test_values.append(2) - self.test_values.append(3) - self.test_2_values.append("x") - self.test_2_values.append("y") - self.test_2_values.append("z") diff --git a/pilot/connections/rdbms/py_study/test_cls_base.py b/pilot/connections/rdbms/py_study/test_cls_base.py deleted file mode 100644 index 676c1f2a5..000000000 --- a/pilot/connections/rdbms/py_study/test_cls_base.py +++ /dev/null @@ -1,13 +0,0 @@ -from abc import ABC, abstractmethod -from pydantic import BaseModel -from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union - - -class TestBase(BaseModel, ABC): - test_values: List = [] - mode: str = "123" - - def test(self): - print(self.__class__.__name__ + ":") - print(self.test_values) - print(self.mode) diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 50f10e360..468381ba2 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -2,6 +2,8 @@ import json import os import uuid from typing import Dict, NamedTuple, List +from decimal import Decimal + from pilot.scene.base_message import ( HumanMessage, ViewMessage, @@ -17,6 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt from pilot.scene.chat_dashboard.data_preparation.report_schma import ( ChartData, ReportData, + ValueItem, ) CFG = Config() @@ -74,7 +77,40 @@ class ChatDashboard(BaseChat): chart_datas: List[ChartData] = [] for chart_item in prompt_response: try: - datas = self.database.run(self.db_connect, chart_item.sql) + field_names, datas = self.database.query_ex( + self.db_connect, chart_item.sql + ) + values: List[ValueItem] = [] + data_map = {} + field_map = {} + index = 0 + for field_name in 
class ValueItem(BaseModel):
    """A single named value of a chart series.

    Serialised by ``dict()`` into a mapping with exactly the keys ``name``,
    ``type`` and ``value``.
    """

    # Label of the data point.
    name: str
    # Series the value belongs to; may be omitted, in which case it is None.
    # Declared as a union so that the None default matches the annotation.
    type: Union[str, None] = None
    # Numeric value of the data point.
    value: float

    def dict(self, *args, **kwargs):
        """Return the item as a plain ``{"name", "type", "value"}`` dict.

        Positional and keyword arguments are accepted so the signature stays
        call-compatible with the base class; they are ignored.
        """
        return {"name": self.name, "type": self.type, "value": self.value}
"values": [value.dict() for value in self.values], "style": self.style, } diff --git a/pilot/scene/chat_dashboard/prompt.py b/pilot/scene/chat_dashboard/prompt.py index 00e40e790..212d24e96 100644 --- a/pilot/scene/chat_dashboard/prompt.py +++ b/pilot/scene/chat_dashboard/prompt.py @@ -13,8 +13,14 @@ PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provid _DEFAULT_TEMPLATE = """ According to the structure definition in the following tables: {table_info} -Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions. -Used to support goal: {input} +Provide professional data analysis to support the goal: +{input} + +Constraint: +Provide multi-dimensional analysis as much as possible according to the target requirements, no less than three and no more than 8 dimensions. +The data columns of the analysis output should not exceed 4. +According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for display, chart type: +{supported_chat_type} Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, use different types of charts as much as possible,chart types: @@ -22,7 +28,7 @@ According to the characteristics of the analyzed data, choose the best one from Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format: {response} -Do not use unprovided fields and do not use unprovided data in the where condition of sql. +Do not use unprovided fields and value in the where condition of sql. Ensure the response is correct json and can be parsed by Python json.loads """