mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-31 15:47:05 +00:00
test: delete unused test files
This commit is contained in:
parent
ec19a0e1df
commit
b68bb2b5d5
@ -268,6 +268,31 @@ class Database:
|
||||
result.insert(0, field_names)
|
||||
return result
|
||||
|
||||
def query_ex(self, session, query, fetch: str = "all"):
    """Execute a read-only SQL query and return column names plus row data.

    Args:
        session: active SQLAlchemy session used to execute the statement.
        query: SQL text to run; an empty/None query short-circuits.
        fetch: "all" to fetch every row, "one" to fetch a single scalar.

    Returns:
        A tuple ``(field_names, result)`` where ``field_names`` is the list
        of column names and ``result`` is the fetched data. Returns
        ``([], [])`` when the query is empty or produces no rows, so the
        ``field_names, datas = query_ex(...)`` unpacking used by callers
        always works (the original returned a bare ``[]`` here, which broke
        tuple unpacking).

    Raises:
        ValueError: if ``fetch`` is neither "one" nor "all".
    """
    print(f"Query[{query}]")
    if not query:
        # Keep the return shape consistent with the normal path.
        return [], []
    cursor = session.execute(text(query))
    if not cursor.returns_rows:
        # The original fell through with `result` unbound here (NameError);
        # return an empty result set instead.
        return [], []
    if fetch == "all":
        result = list(cursor.fetchall())
    elif fetch == "one":
        row = cursor.fetchone()
        # Guard against an empty result; the original indexed None and also
        # listified the scalar below, raising TypeError for non-iterables.
        result = row[0] if row is not None else None  # type: ignore
    else:
        raise ValueError("Fetch parameter must be either 'one' or 'all'")
    # cursor.keys() already yields the column names; the original's
    # `i[0:]` was a no-op copy of each name.
    field_names = list(cursor.keys())
    return field_names, result
|
||||
|
||||
def run(self, session, command: str, fetch: str = "all") -> List:
|
||||
"""Execute a SQL command and return a string representing the results."""
|
||||
print("SQL:" + command)
|
||||
|
@ -1,94 +0,0 @@
|
||||
from pilot.configs.config import Config
|
||||
import pandas as pd
|
||||
from sqlalchemy import create_engine, pool
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from matplotlib.font_manager import FontProperties
|
||||
from pyecharts.charts import Bar
|
||||
from pyecharts import options as opts
|
||||
from test_cls_1 import TestBase, Test1
|
||||
from test_cls_2 import Test2
|
||||
|
||||
CFG = Config()
|
||||
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# # 创建连接池
|
||||
# engine = create_engine('mysql+pymysql://root:aa123456@localhost:3306/gpt-user')
|
||||
#
|
||||
# # 从连接池中获取连接
|
||||
#
|
||||
#
|
||||
# # 归还连接到连接池中
|
||||
#
|
||||
# # 执行SQL语句并将结果转化为DataFrame
|
||||
# query = "SELECT * FROM users"
|
||||
# df = pd.read_sql(query, engine.connect())
|
||||
# df.style.set_properties(subset=['name'], **{'font-weight': 'bold'})
|
||||
# # 导出为HTML文件
|
||||
# with open('report.html', 'w') as f:
|
||||
# f.write(df.style.render())
|
||||
#
|
||||
# # # 设置中文字体
|
||||
# # font = FontProperties(fname='SimHei.ttf', size=14)
|
||||
# #
|
||||
# # colors = np.random.rand(df.shape[0])
|
||||
# # df.plot.scatter(x='city', y='user_name', c=colors)
|
||||
# # plt.show()
|
||||
#
|
||||
# # 查看DataFrame
|
||||
# print(df.head())
|
||||
#
|
||||
#
|
||||
# # 创建数据
|
||||
# x_data = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
|
||||
# y_data = [820, 932, 901, 934, 1290, 1330, 1320]
|
||||
#
|
||||
# # 生成图表
|
||||
# bar = (
|
||||
# Bar()
|
||||
# .add_xaxis(x_data)
|
||||
# .add_yaxis("销售额", y_data)
|
||||
# .set_global_opts(title_opts=opts.TitleOpts(title="销售额统计"))
|
||||
# )
|
||||
#
|
||||
# # 生成HTML文件
|
||||
# bar.render('report.html')
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
|
||||
# def __extract_json(s):
|
||||
# i = s.index("{")
|
||||
# count = 1 # 当前所在嵌套深度,即还没闭合的'{'个数
|
||||
# for j, c in enumerate(s[i + 1 :], start=i + 1):
|
||||
# if c == "}":
|
||||
# count -= 1
|
||||
# elif c == "{":
|
||||
# count += 1
|
||||
# if count == 0:
|
||||
# break
|
||||
# assert count == 0 # 检查是否找到最后一个'}'
|
||||
# return s[i : j + 1]
|
||||
#
|
||||
# ss = """here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:select u.city, count(*) as order_countfrom tran_order oleft join user u on o.user_id = u.idgroup by u.city;this will return the number of orders for each city that has at least one order. we can use this data to generate a histogram that shows the distribution of orders across different cities.here's the response in the required format:{ "thoughts": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities:\n\nselect u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;", "speak": "here's a sql statement that can be used to generate a histogram to analyze the distribution of user orders in different cities.", "command": { "name": "histogram-executor", "args": { "title": "distribution of user orders in different cities", "sql": "select u.city, count(*) as order_count\nfrom tran_order o\nleft join user u on o.user_id = u.id\ngroup by u.city;" } }}"""
|
||||
# print(__extract_json(ss))
|
||||
|
||||
if __name__ == "__main__":
    # A list of (key, value) pairs to be folded into a dictionary.
    pairs = [("key1", "value1"), ("key2", "value2"), ("key3", "value3")]

    # dict() over an iterable of 2-tuples is equivalent to the
    # {k: v for k, v in pairs} comprehension the original used.
    mapping = dict(pairs)

    print(mapping)
|
@ -1,15 +0,0 @@
|
||||
import json
|
||||
from pilot.common.sql_database import Database
|
||||
from pilot.configs.config import Config
|
||||
|
||||
CFG = Config()
|
||||
|
||||
if __name__ == "__main__":
    # connect = CFG.local_db.get_session("gpt-user")
    # datas = CFG.local_db.run(connect, "SELECT * FROM users; ")
    # print(datas)

    # FIX: the original bound this literal to the name `str`, shadowing the
    # builtin for the rest of the block; use a descriptive name instead.
    payload = """{ "thoughts": "thought text", "sql": "SELECT COUNT(DISTINCT user_id) FROM transactions_order WHERE user_id IN (SELECT DISTINCT user_id FROM users WHERE country='China') AND create_time BETWEEN 20230101 AND 20230131" ,}"""

    # Prints the index of the first '[' in the payload (-1 when absent).
    print(payload.find("["))
|
@ -1,16 +0,0 @@
|
||||
import json
|
||||
import os
|
||||
import duckdb
|
||||
|
||||
# Default on-disk location of the chat-history DuckDB file; overridable
# through the DB_DUCKDB_PATH environment variable.
default_db_path = os.path.join(os.getcwd(), "message")
duckdb_path = os.getenv("DB_DUCKDB_PATH", default_db_path + "/chat_history.db")

if __name__ == "__main__":
    history_file = "../../../message/chat_history.db"
    if os.path.isfile(history_file):
        cursor = duckdb.connect(history_file).cursor()
        # cursor.execute("SELECT * FROM chat_history limit 20")
        cursor.execute(
            "SELECT * FROM chat_history where conv_uid ='b54ae5fe-1624-11ee-a271-b26789cc3e58'"
        )
        data = cursor.fetchall()
        print(data)
|
@ -1,92 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
|
||||
class Test(Enum):
    """Enum whose members carry a (code, v, flag) payload tuple.

    The member tuples are unpacked into the ``code``, ``v`` and ``flag``
    attributes; ``flag`` defaults to False when a member supplies only two
    elements (see ZZZ).
    """

    XXX = ("x", "1", True)
    YYY = ("Y", "2", False)
    ZZZ = ("Z", "3")

    def __init__(self, *payload):
        # Enum invokes __init__ once per member with the value tuple
        # unpacked into positional arguments.
        self.code, self.v, *rest = payload
        self.flag = rest[0] if rest else False
|
||||
|
||||
|
||||
class Scene:
    """Descriptor for a chat scene: code, display name and description.

    Args:
        code: unique machine-readable identifier of the scene.
        name: human-readable scene name.
        describe: short description shown to users.
        param_types: optional list of parameter-type labels the scene needs.
        is_inner: whether the scene is internal-only (not user selectable).
    """

    def __init__(
        self, code, name, describe, param_types: List = None, is_inner: bool = False
    ):
        self.code = code
        self.name = name
        self.describe = describe
        # FIX: the original used a mutable default (`param_types: List = []`),
        # which aliases a single shared list across every Scene constructed
        # without an explicit argument; build a fresh list instead.
        self.param_types = param_types if param_types is not None else []
        self.is_inner = is_inner
|
||||
|
||||
|
||||
class ChatScene(Enum):
    """Registry of all chat scenes, each member wrapping a Scene descriptor."""

    ChatWithDbExecute = Scene(
        "chat_with_db_execute",
        "Chat Data",
        "Dialogue with your private data through natural language.",
        ["DB Select"],
    )
    ChatWithDbQA = Scene(
        "chat_with_db_qa",
        "Chat Meta Data",
        "Have a Professional Conversation with Metadata.",
        ["DB Select"],
    )
    ChatExecution = Scene(
        "chat_execution",
        "Chat Plugin",
        "Use tools through dialogue to accomplish your goals.",
        ["Plugin Select"],
    )
    ChatDefaultKnowledge = Scene(
        "chat_default_knowledge",
        "Chat Default Knowledge",
        "Dialogue through natural language and private documents and knowledge bases.",
    )
    ChatNewKnowledge = Scene(
        "chat_new_knowledge",
        "Chat New Knowledge",
        "Dialogue through natural language and private documents and knowledge bases.",
        ["Knowledge Select"],
    )
    ChatUrlKnowledge = Scene(
        "chat_url_knowledge",
        "Chat URL",
        "Dialogue through natural language and private documents and knowledge bases.",
        ["Url Input"],
    )
    # BUG FIX: the original passed `True` positionally, so it landed in the
    # `param_types` parameter of Scene instead of `is_inner`; pass it by
    # keyword so this scene is actually flagged as internal.
    InnerChatDBSummary = Scene(
        "inner_chat_db_summary", "DB Summary", "Db Summary.", is_inner=True
    )

    ChatNormal = Scene(
        "chat_normal", "Chat Normal", "Native LLM large model AI dialogue."
    )
    ChatDashboard = Scene(
        "chat_dashboard",
        "Chat Dashboard",
        "Provide you with professional analysis reports through natural language.",
        ["DB Select"],
    )
    ChatKnowledge = Scene(
        "chat_knowledge",
        "Chat Knowledge",
        "Dialogue through natural language and private documents and knowledge bases.",
        ["Knowledge Space Select"],
    )

    def scene_value(self):
        """Return the machine-readable code of the wrapped Scene."""
        return self.value.code

    def scene_name(self):
        """Return the human-readable name of the wrapped Scene."""
        # Use the public `.value` accessor, consistent with scene_value()
        # (the original mixed `self.value` and `self._value_`).
        return self.value.name
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-check: print the scene code of the ChatWithDbExecute member.
    print(ChatScene.ChatWithDbExecute.scene_value())
    # print(ChatScene.ChatWithDbExecute.value.describe)
|
@ -1,12 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from test_cls_base import TestBase
|
||||
|
||||
|
||||
class Test1(TestBase):
    """TestBase subclass that overrides ``mode`` and records string markers."""

    mode: str = "456"

    def write(self):
        """Append the markers "x", "y" and "g" to this instance's test_values."""
        self.test_values.extend(("x", "y", "g"))
|
@ -1,17 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from test_cls_base import TestBase
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
|
||||
class Test2(TestBase):
    """TestBase subclass that also maintains its own ``test_2_values`` list."""

    test_2_values: List = []
    mode: str = "789"

    def write(self):
        """Append 1, 2, 3 to test_values and "x", "y", "z" to test_2_values."""
        self.test_values.extend((1, 2, 3))
        self.test_2_values.extend(("x", "y", "z"))
|
@ -1,13 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
|
||||
|
||||
|
||||
class TestBase(BaseModel, ABC):
    """Abstract pydantic base model holding a value buffer and a mode tag."""

    test_values: List = []
    mode: str = "123"

    def test(self):
        """Print the concrete class name, the collected values and the mode."""
        print(f"{type(self).__name__}:")
        print(self.test_values)
        print(self.mode)
|
@ -2,6 +2,8 @@ import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, NamedTuple, List
|
||||
from decimal import Decimal
|
||||
|
||||
from pilot.scene.base_message import (
|
||||
HumanMessage,
|
||||
ViewMessage,
|
||||
@ -17,6 +19,7 @@ from pilot.scene.chat_dashboard.prompt import prompt
|
||||
from pilot.scene.chat_dashboard.data_preparation.report_schma import (
|
||||
ChartData,
|
||||
ReportData,
|
||||
ValueItem,
|
||||
)
|
||||
|
||||
CFG = Config()
|
||||
@ -74,7 +77,40 @@ class ChatDashboard(BaseChat):
|
||||
chart_datas: List[ChartData] = []
|
||||
for chart_item in prompt_response:
|
||||
try:
|
||||
datas = self.database.run(self.db_connect, chart_item.sql)
|
||||
field_names, datas = self.database.query_ex(
|
||||
self.db_connect, chart_item.sql
|
||||
)
|
||||
values: List[ValueItem] = []
|
||||
data_map = {}
|
||||
field_map = {}
|
||||
index = 0
|
||||
for field_name in field_names:
|
||||
data_map.update({f"{field_name}": [row[index] for row in datas]})
|
||||
index += 1
|
||||
if not data_map[field_name]:
|
||||
field_map.update({f"{field_name}": False})
|
||||
else:
|
||||
field_map.update(
|
||||
{
|
||||
f"{field_name}": all(
|
||||
isinstance(item, (int, float, Decimal))
|
||||
for item in data_map[field_name]
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
for field_name in field_names[1:]:
|
||||
if not field_map[field_name]:
|
||||
print("more than 2 non-numeric column")
|
||||
else:
|
||||
for data in datas:
|
||||
value_item = ValueItem(
|
||||
name=data[0],
|
||||
type=field_name,
|
||||
value=data[field_names.index(field_name)],
|
||||
)
|
||||
values.append(value_item)
|
||||
|
||||
chart_datas.append(
|
||||
ChartData(
|
||||
chart_uid=str(uuid.uuid1()),
|
||||
@ -82,8 +118,8 @@ class ChatDashboard(BaseChat):
|
||||
chart_type=chart_item.showcase,
|
||||
chart_desc=chart_item.thoughts,
|
||||
chart_sql=chart_item.sql,
|
||||
column_name=datas[0],
|
||||
values=datas,
|
||||
column_name=field_names,
|
||||
values=values,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
@ -4,6 +4,15 @@ from typing import TypeVar, Union, List, Generic, Any
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
class ValueItem(BaseModel):
    """One chart data point: a named, optionally typed numeric value.

    Attributes:
        name: label of the data point (the first column's value upstream).
        type: series/column this value belongs to; None when untyped.
        value: the numeric value itself.
    """

    name: str
    # FIX: the original declared `type: str = None`, a non-optional field
    # with a None default; `Union[str, None]` (Union is already imported at
    # the top of this module) makes the optionality explicit and valid.
    type: Union[str, None] = None
    value: float

    def dict(self, *args, **kwargs):
        """Serialize to a plain dict (overrides pydantic's default dict())."""
        return {"name": self.name, "type": self.type, "value": self.value}
|
||||
|
||||
|
||||
class ChartData(BaseModel):
|
||||
chart_uid: str
|
||||
chart_name: str
|
||||
@ -11,7 +20,7 @@ class ChartData(BaseModel):
|
||||
chart_desc: str
|
||||
chart_sql: str
|
||||
column_name: List
|
||||
values: List
|
||||
values: List[ValueItem]
|
||||
style: Any = None
|
||||
|
||||
def dict(self, *args, **kwargs):
|
||||
@ -22,7 +31,7 @@ class ChartData(BaseModel):
|
||||
"chart_desc": self.chart_desc,
|
||||
"chart_sql": self.chart_sql,
|
||||
"column_name": [str(item) for item in self.column_name],
|
||||
"values": [[str(item) for item in sublist] for sublist in self.values],
|
||||
"values": [value.dict() for value in self.values],
|
||||
"style": self.style,
|
||||
}
|
||||
|
||||
|
@ -13,8 +13,14 @@ PROMPT_SCENE_DEFINE = """You are a {dialect} data analysis expert, please provid
|
||||
_DEFAULT_TEMPLATE = """
|
||||
According to the structure definition in the following tables:
|
||||
{table_info}
|
||||
Provide professional data analysis, use as few dimensions as possible, but no less than three, and no more than eight dimensions.
|
||||
Used to support goal: {input}
|
||||
Provide professional data analysis to support the goal:
|
||||
{input}
|
||||
|
||||
Constraint:
|
||||
Provide multi-dimensional analysis as much as possible according to the target requirements, no less than three and no more than 8 dimensions.
|
||||
The data columns of the analysis output should not exceed 4.
|
||||
According to the characteristics of the analyzed data, choose the most suitable one from the charts provided below for display, chart type:
|
||||
{supported_chat_type}
|
||||
|
||||
Pay attention to the length of the output content of the analysis result, do not exceed 4000tokens
|
||||
According to the characteristics of the analyzed data, choose the best one from the charts provided below to display, use different types of charts as much as possible,chart types:
|
||||
@ -22,7 +28,7 @@ According to the characteristics of the analyzed data, choose the best one from
|
||||
|
||||
Give {dialect} data analysis SQL, analysis title, display method and analytical thinking,respond in the following json format:
|
||||
{response}
|
||||
Do not use unprovided fields and do not use unprovided data in the where condition of sql.
|
||||
Do not use unprovided fields and value in the where condition of sql.
|
||||
Ensure the response is correct json and can be parsed by Python json.loads
|
||||
"""
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user