test:delete unusual test file

This commit is contained in:
aries_ckt 2023-07-06 19:10:09 +08:00
parent ec19a0e1df
commit b68bb2b5d5
12 changed files with 84 additions and 267 deletions

View File

@ -268,6 +268,31 @@ class Database:
result.insert(0, field_names)
return result
def query_ex(self, session, query, fetch: str = "all"):
"""
only for query
Args:
session:
query:
fetch:
Returns:
"""
print(f"Query[{query}]")
if not query:
return []
cursor = session.execute(text(query))
if cursor.returns_rows:
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0] # type: ignore
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
field_names = list(i[0:] for i in cursor.keys())
result = list(result)
return field_names, result
def run(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)

View File

@ -1,94 +0,0 @@
from pilot.configs.config import Config
import pandas as pd
from sqlalchemy import create_engine, pool
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from pyecharts.charts import Bar
from pyecharts import options as opts
from test_cls_1 import TestBase, Test1
from test_cls_2 import Test2
CFG = Config()
#
# if __name__ == "__main__":
# # 创建连接池
# engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
#
# # 从连接池中获取连接
#
#
# # 归还连接到连接池中
#
# # 执行SQL语句并将结果转化为DataFrame
# query = "SELECT * FROM users"
# df = pd.read_sql(query, engine.connect())
# df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
# # 导出为HTML文件
# with open('report.html', 'w') as f:
# f.write(df.style.render())
#
# # # 设置中文字体
# # font = FontProperties(fname='SimHei.ttf', size=14)
# #
# # colors = np.random.rand(df.shape[0])
# # df.plot.scatter(x='city', y='user_name', c=colors)
# # plt.show()
#
# # 查看DataFrame
# print(df.head())
#
#
# # 创建数据
# x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
# y_data = [820, 932, 901, 934, 1290, 1330, 1320]
#
# # 生成图表
# bar = (
# Bar()
# .add_xaxis(x_data)
# .add_yaxis("销售额", y_data)
# .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
# )
#
# # 生成HTML文件
# bar.render('report.html')
#
#
# if __name__ == "__main__":
# def __extract_json(s):
# i = s.index("{")
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
# for j, c in enumerate(s[i + 1 :], start=i + 1):
# if c == "}":
# count -= 1
# elif c == "{":
# count += 1
# if count == 0:
# break
# assert count == 0 # 检查是否找到最后一个'}'
# return s[i : j + 1]
#
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
# print(__extract_json(ss))
if __name__ == "__main__":
# test1 = Test1()
# test2 = Test2()
# test1.write()
# test1.test()
# test2.write()
# test1.test()
# test2.test()
# 定义包含元组的列表
data = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")]
# 使用字典解析将列表转换为字典
result = {k: v for k, v in data}
print(result)

View File

@ -1,15 +0,0 @@
import json
from pilot.common.sql_database import Database
from pilot.configs.config import Config
CFG = Config()
if __name__ == "__main__":
# connect = CFG.local_db.get_session("gpt-user")
# datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
# print(datas)
str = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""
print(str.find("["))

View File

@ -1,16 +0,0 @@
import json
import os
import duckdb
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")
if __name__ == "__main__":
if os.path.isfile("../../../message/chat_history.db"):
cursor = duckdb.connect("../../../message/chat_history.db").cursor()
# cursor.execute("SELECT * FROM chat_history limit 20")
cursor.execute(
"SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'"
)
data = cursor.fetchall()
print(data)

View File

@ -1,92 +0,0 @@
from enum import Enum
from typing import List
class Test(Enum):
XXX = ("x", "1", True)
YYY = ("Y", "2", False)
ZZZ = ("Z", "3")
def __init__(self, code, v, flag=False):
self.code = code
self.v = v
self.flag = flag
class Scene:
def __init__(
self, code, name, describe, param_types: List = [], is_inner: bool = False
):
self.code = code
self.name = name
self.describe = describe
self.param_types = param_types
self.is_inner = is_inner
class ChatScene(Enum):
ChatWithDbExecute = Scene(
"chat_with_db_execute",
"Chat Data",
"Dialogue with your private data through natural language.",
["DB Select"],
)
ChatWithDbQA = Scene(
"chat_with_db_qa",
"Chat Meta Data",
"Have a Professional Conversation with Metadata.",
["DB Select"],
)
ChatExecution = Scene(
"chat_execution",
"Chat Plugin",
"Use tools through dialogue to accomplish your goals.",
["Plugin Select"],
)
ChatDefaultKnowledge = Scene(
"chat_default_knowledge",
"Chat Default Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
)
ChatNewKnowledge = Scene(
"chat_new_knowledge",
"Chat New Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Select"],
)
ChatUrlKnowledge = Scene(
"chat_url_knowledge",
"Chat URL",
"Dialogue through natural language and private documents and knowledge bases.",
["Url Input"],
)
InnerChatDBSummary = Scene(
"inner_chat_db_summary", "DB Summary", "Db Summary.", True
)
ChatNormal = Scene(
"chat_normal", "Chat Normal", "Native LLM large model AI dialogue."
)
ChatDashboard = Scene(
"chat_dashboard",
"Chat Dashboard",
"Provide you with professional analysis reports through natural language.",
["DB Select"],
)
ChatKnowledge = Scene(
"chat_knowledge",
"Chat Knowledge",
"Dialogue through natural language and private documents and knowledge bases.",
["Knowledge Space Select"],
)
def scene_value(self):
return self.value.code
def scene_name(self):
return self._value_.name
if __name__ == "__main__":
print(ChatScene.ChatWithDbExecute.scene_value())
# print(ChatScene.ChatWithDbExecute.value.describe)

View File

@ -1,12 +0,0 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
class Test1(TestBase):
mode: str = "456"
def write(self):
self.test_values.append("x")
self.test_values.append("y")
self.test_values.append("g")

View File

@ -1,17 +0,0 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from test_cls_base import TestBase
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class Test2(TestBase):
test_2_values: List = []
mode: str = "789"
def write(self):
self.test_values.append(1)
self.test_values.append(2)
self.test_values.append(3)
self.test_2_values.append("x")
self.test_2_values.append("y")
self.test_2_values.append("z")

View File

@ -1,13 +0,0 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
class TestBase(BaseModel, ABC):
test_values: List = []
mode: str = "123"
def test(self):
print(self.__class__.__name__ + ":")
print(self.test_values)
print(self.mode)

View File

@ -2,6 +2,8 @@ import json
import os
import uuid
from typing import Dict, NamedTuple, List
from decimal import Decimal
from pilot.scene.base_message import (
HumanMessage,
ViewMessage,
@ -17,6 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
ChartData,
ReportData,
ValueItem,
)
CFG = Config()
@ -74,7 +77,40 @@ class ChatDashboard(BaseChat):
chart_datas: List[ChartData] = []
for chart_item in prompt_response:
try:
datas = self.database.run(self.db_connect, chart_item.sql)
field_names, datas = self.database.query_ex(
self.db_connect, chart_item.sql
)
values: List[ValueItem] = []
data_map = {}
field_map = {}
index = 0
for field_name in field_names:
data_map.update({f"{field_name}": [row[index] for row in datas]})
index += 1
if not data_map[field_name]:
field_map.update({f"{field_name}": False})
else:
field_map.update(
{
f"{field_name}": all(
isinstance(item, (int, float, Decimal))
for item in data_map[field_name]
)
}
)
for field_name in field_names[1:]:
if not field_map[field_name]:
print("more than 2 non-numeric column")
else:
for data in datas:
value_item = ValueItem(
name=data[0],
type=field_name,
value=data[field_names.index(field_name)],
)
values.append(value_item)
chart_datas.append(
ChartData(
chart_uid=str(uuid.uuid1()),
@ -82,8 +118,8 @@ class ChatDashboard(BaseChat):
chart_type=chart_item.showcase,
chart_desc=chart_item.thoughts,
chart_sql=chart_item.sql,
column_name=datas[0],
values=datas,
column_name=field_names,
values=values,
)
)
except Exception as e:

View File

@ -4,6 +4,15 @@ from typing import TypeVar, Union, List, Generic, Any
from dataclasses import dataclass, asdict
class ValueItem(BaseModel):
name: str
type: str = None
value: float
def dict(self, *args, **kwargs):
return {"name": self.name, "type": self.type, "value": self.value}
class ChartData(BaseModel):
chart_uid: str
chart_name: str
@ -11,7 +20,7 @@ class ChartData(BaseModel):
chart_desc: str
chart_sql: str
column_name: List
values: List
values: List[ValueItem]
style: Any = None
def dict(self, *args, **kwargs):
@ -22,7 +31,7 @@ class ChartData(BaseModel):
"chart_desc": self.chart_desc,
"chart_sql": self.chart_sql,
"column_name": [str(item) for item in self.column_name],
"values": [[str(item) for item in sublist] for sublist in self.values],
"values": [value.dict() for value in self.values],
"style": self.style,
}

View File

@ -13,8 +13,14 @@ PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provid
_DEFAULT_TEMPLATE = """
According to the structure definition in the following tables:
{table_info}
Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions.
Used to support goal: {input}
Provide professional data analysis to support the goal:
{input}
Constraint:
Provide multi-dimensional analysis as much as possible according to the target requirements, no less than three and no more than 8 dimensions.
The data columns of the analysis output should not exceed 4.
According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for display, chart type:
{supported_chat_type}
Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens
According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, use different types of charts as much as possiblechart types:
@ -22,7 +28,7 @@ According to the characteristics of the analyzed data, choose the best one from
Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
{response}
Do not use unprovided fields and do not use unprovided data in the where condition of sql.
Do not use unprovided fields and value in the where condition of sql.
Ensure the response is correct json and can be parsed by Python json.loads
"""