style:fmt

This commit is contained in:
aries_ckt 2023-08-29 19:57:33 +08:00
parent 0efaffc031
commit 71b9cd14a6
36 changed files with 413 additions and 280 deletions

View File

@ -20,8 +20,17 @@ CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img") static_message_img_path = os.path.join(os.getcwd(), "message/img")
def zh_font_set(): def zh_font_set():
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -29,11 +38,14 @@ def zh_font_set():
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str: def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},") logger.info(f"response_line_chart:{speak},")
@ -44,7 +56,15 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -52,31 +72,34 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts} rc = {"font.sans-serif": can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark") sns.set_style("dark")
sns.color_palette("hls", 10) sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7) sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context='notebook', style='ticks', rc=rc) sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax) sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
chart_name = "line_" + str(uuid.uuid1()) + ".png" chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str: def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},") logger.info(f"response_bar_chart:{speak},")
columns = df.columns.tolist() columns = df.columns.tolist()
@ -85,7 +108,15 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -93,29 +124,32 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts} rc = {"font.sans-serif": can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark") sns.set_style("dark")
sns.color_palette("hls", 10) sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7) sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context='notebook', style='ticks', rc=rc) sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax) sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img return html_img
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str: def response_pie_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},") logger.info(f"response_pie_chart:{speak},")
columns = df.columns.tolist() columns = df.columns.tolist()
@ -123,7 +157,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
raise ValueError("No Data") raise ValueError("No Data")
# set font # set font
# zh_font_set() # zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager() fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist) mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = [] can_use_fonts = []
@ -131,23 +173,35 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts: if font_name in mat_fonts:
can_use_fonts.append(font_name) can_use_fonts.append(font_name)
if len(can_use_fonts) > 0: if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts plt.rcParams["font.sans-serif"] = can_use_fonts
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小 # 手动设置 labels 的位置和大小
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10) ax.legend(
loc="upper right",
bbox_to_anchor=(0, 0, 1, 1),
labels=df[columns[0]].values,
fontsize=10,
)
plt.axis('equal') # 使饼图为正圆形 plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0]) # plt.title(columns[0])
chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100) plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />""" html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

View File

@ -11,7 +11,11 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_table", "Table display, suitable for display with many display columns or non-numeric columns", '"speak": "<speak>", "df":"<data frame>"') @command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_table(speak: str, df: DataFrame) -> str: def response_table(speak: str, df: DataFrame) -> str:
logger.info(f"response_table:{speak}") logger.info(f"response_table:{speak}")
html_table = df.to_html(index=False, escape=False, sparsify=False) html_table = df.to_html(index=False, escape=False, sparsify=False)

View File

@ -10,8 +10,11 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_data_text", "Text display, the default display method, suitable for single-line or simple content display", @command(
'"speak": "<speak>", "df":"<data frame>"') "response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str: def response_data_text(speak: str, df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}") logger.info(f"response_data_text:{speak}")
data = df.values data = df.values

View File

@ -1,6 +1,6 @@
def csv_colunm_foramt(val): def csv_colunm_foramt(val):
if str(val).find("$") >= 0: if str(val).find("$") >= 0:
return float(val.replace('$', '').replace(',', '')) return float(val.replace("$", "").replace(",", ""))
if str(val).find("¥") >= 0: if str(val).find("¥") >= 0:
return float(val.replace('¥', '').replace(',', '')) return float(val.replace("¥", "").replace(",", ""))
return val return val

View File

@ -325,6 +325,7 @@ class RDBMSDatabase(BaseConnect):
result = list(result) result = list(result)
return field_names, result return field_names, result
return [] return []
def run(self, command: str, fetch: str = "all") -> List: def run(self, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.""" """Execute a SQL command and return a string representing the results."""
print("SQL:" + command) print("SQL:" + command)

View File

@ -94,7 +94,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit() cursor.commit()
self.connect.commit() self.connect.commit()
def update(self, messages: List[OnceConversation]) -> None: def update(self, messages: List[OnceConversation]) -> None:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute(
@ -161,7 +160,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return {} return {}
def get_messages(self) -> List[OnceConversation]: def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor() cursor = self.connect.cursor()
cursor.execute( cursor.execute(

View File

@ -110,7 +110,6 @@ async def db_connect_delete(db_name: str = None):
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types(): async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types() support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
db_type_infos = [] db_type_infos = []
for type in support_types: for type in support_types:
@ -130,7 +129,7 @@ async def dialogue_list(user_id: str = None):
chat_mode = item.get("chat_mode") chat_mode = item.get("chat_mode")
messages = json.loads(item.get("messages")) messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x['chat_order']) last_round = max(messages, key=lambda x: x["chat_order"])
if "param_value" in last_round: if "param_value" in last_round:
select_param = last_round["param_value"] select_param = last_round["param_value"]
else: else:
@ -139,7 +138,7 @@ async def dialogue_list(user_id: str = None):
conv_uid=conv_uid, conv_uid=conv_uid,
user_input=summary, user_input=summary,
chat_mode=chat_mode, chat_mode=chat_mode,
select_param=select_param select_param=select_param,
) )
dialogues.append(conv_vo) dialogues.append(conv_vo)
@ -213,7 +212,9 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
), ),
) )
## chat prepare ## chat prepare
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename) dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
)
chat: BaseChat = get_chat_instance(dialogue) chat: BaseChat = get_chat_instance(dialogue)
resp = chat.prepare() resp = chat.prepare()
@ -229,6 +230,7 @@ async def dialogue_delete(con_uid: str):
history_mem.delete() history_mem.delete()
return Result.succ(None) return Result.succ(None)
def get_hist_messages(conv_uid: str): def get_hist_messages(conv_uid: str):
message_vos: List[MessageVo] = [] message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(conv_uid) history_mem = DuckdbHistoryMemory(conv_uid)
@ -264,7 +266,7 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = { chat_param = {
"chat_session_id": dialogue.conv_uid, "chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input, "user_input": dialogue.user_input,
"select_param": dialogue.select_param "select_param": dialogue.select_param,
} }
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
return chat return chat

View File

@ -22,7 +22,7 @@ from pilot.openapi.editor_view_model import (
ChartDetail, ChartDetail,
ChatChartEditContext, ChatChartEditContext,
ChatSqlEditContext, ChatSqlEditContext,
DbTable DbTable,
) )
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
@ -38,7 +38,9 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable]) @router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""): async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}") logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
tables = db_conn.get_table_names() tables = db_conn.get_table_names()
@ -49,8 +51,15 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
fields = db_conn.get_fields(table) fields = db_conn.get_fields(table)
for field in fields: for field in fields:
table_node.children.append( table_node.children.append(
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3], DataNode(
comment=field[-1])) title=field[0],
key=field[0],
type=field[1],
default_value=field[2],
can_null=field[3],
comment=field[-1],
)
)
return Result.succ(db_node) return Result.succ(db_node)
@ -68,8 +77,11 @@ async def get_editor_sql_rounds(con_uid: str):
if element["type"] == "human": if element["type"] == "human":
round_name = element["data"]["content"] round_name = element["data"]["content"]
if once.get("param_value"): if once.get("param_value"):
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"], round: ChatDbRounds = ChatDbRounds(
round_name=round_name) round=once["chat_order"],
db_name=once["param_value"],
round_name=round_name,
)
result.append(round) result.append(round)
return Result.succ(result) return Result.succ(result)
@ -84,8 +96,14 @@ async def get_editor_sql(con_uid: str, round: int):
if int(once["chat_order"]) == round: if int(once["chat_order"]) == round:
for element in once["messages"]: for element in once["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
logger.info(f'history ai json resp:{element["data"]["content"]}') logger.info(
context = element["data"]["content"].replace("\\n", " ").replace("\n", " ") f'history ai json resp:{element["data"]["content"]}'
)
context = (
element["data"]["content"]
.replace("\\n", " ")
.replace("\n", " ")
)
return Result.succ(json.loads(context)) return Result.succ(json.loads(context))
return Result.faild(msg="not have sql!") return Result.faild(msg="not have sql!")
@ -93,8 +111,8 @@ async def get_editor_sql(con_uid: str, round: int):
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData]) @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()): async def editor_sql_run(run_param: dict = Body()):
logger.info(f"editor_sql_run:{run_param}") logger.info(f"editor_sql_run:{run_param}")
db_name = run_param['db_name'] db_name = run_param["db_name"]
sql = run_param['sql'] sql = run_param["sql"]
if not db_name and not sql: if not db_name and not sql:
return Result.faild("SQL run param error") return Result.faild("SQL run param error")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
@ -104,18 +122,17 @@ async def editor_sql_run(run_param: dict = Body()):
colunms, sql_result = conn.query_ex(sql) colunms, sql_result = conn.query_ex(sql)
# 计算执行耗时 # 计算执行耗时
end_time = time.time() * 1000 end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(result_info="", sql_run_data: SqlRunData = SqlRunData(
result_info="",
run_cost=(end_time - start_time) / 1000, run_cost=(end_time - start_time) / 1000,
colunms=colunms, colunms=colunms,
values=sql_result values=sql_result,
) )
return Result.succ(sql_run_data) return Result.succ(sql_run_data)
except Exception as e: except Exception as e:
return Result.succ(SqlRunData(result_info=str(e), return Result.succ(
run_cost=0, SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
colunms=[], )
values=[]
))
@router.post("/v1/sql/editor/submit") @router.post("/v1/sql/editor/submit")
@ -126,18 +143,24 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
if history_messages: if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0] edit_round = list(
filter(
lambda x: x["chat_order"] == sql_edit_context.conv_round,
history_messages,
)
)[0]
if edit_round: if edit_round:
for element in edit_round["messages"]: for element in edit_round["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
db_resp = json.loads(element["data"]["content"]) db_resp = json.loads(element["data"]["content"])
db_resp['thoughts'] = sql_edit_context.new_speak db_resp["thoughts"] = sql_edit_context.new_speak
db_resp['sql'] = sql_edit_context.new_sql db_resp["sql"] = sql_edit_context.new_sql
element["data"]["content"] = json.dumps(db_resp) element["data"]["content"] = json.dumps(db_resp)
if element["type"] == "view": if element["type"] == "view":
data_loader = DbDataLoader() data_loader = DbDataLoader()
element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql), element["data"]["content"] = data_loader.get_table_view_by_conn(
sql_edit_context.new_speak) conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
)
history_mem.update(history_messages) history_mem.update(history_messages)
return Result.succ(None) return Result.succ(None)
return Result.faild(msg="Edit Faild!") return Result.faild(msg="Edit Faild!")
@ -145,16 +168,21 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
@router.get("/v1/editor/chart/list", response_model=Result[ChartList]) @router.get("/v1/editor/chart/list", response_model=Result[ChartList])
async def get_editor_chart_list(con_uid: str): async def get_editor_chart_list(con_uid: str):
logger.info(f"get_editor_sql_rounds:{con_uid}", ) logger.info(
f"get_editor_sql_rounds:{con_uid}",
)
history_mem = DuckdbHistoryMemory(con_uid) history_mem = DuckdbHistoryMemory(con_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order']) last_round = max(history_messages, key=lambda x: x["chat_order"])
db_name = last_round["param_value"] db_name = last_round["param_value"]
for element in last_round["messages"]: for element in last_round["messages"]:
if element["type"] == "ai": if element["type"] == "ai":
chart_list: ChartList = ChartList(round=last_round['chat_order'], db_name=db_name, chart_list: ChartList = ChartList(
charts=json.loads(element["data"]["content"])) round=last_round["chat_order"],
db_name=db_name,
charts=json.loads(element["data"]["content"]),
)
return Result.succ(chart_list) return Result.succ(chart_list)
return Result.faild(msg="Not have charts!") return Result.faild(msg="Not have charts!")
@ -162,32 +190,39 @@ async def get_editor_chart_list(con_uid: str):
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail]) @router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(param: dict = Body()): async def get_editor_chart_info(param: dict = Body()):
logger.info(f"get_editor_chart_info:{param}") logger.info(f"get_editor_chart_info:{param}")
conv_uid = param['con_uid'] conv_uid = param["con_uid"]
chart_title = param['chart_title'] chart_title = param["chart_title"]
history_mem = DuckdbHistoryMemory(conv_uid) history_mem = DuckdbHistoryMemory(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages() history_messages: List[OnceConversation] = history_mem.get_messages()
if history_messages: if history_messages:
last_round = max(history_messages, key=lambda x: x['chat_order']) last_round = max(history_messages, key=lambda x: x["chat_order"])
db_name = last_round["param_value"] db_name = last_round["param_value"]
if not db_name: if not db_name:
logger.error("this dashboard dialogue version too old, can't support editor!") logger.error(
return Result.faild(msg="this dashboard dialogue version too old, can't support editor!") "this dashboard dialogue version too old, can't support editor!"
)
return Result.faild(
msg="this dashboard dialogue version too old, can't support editor!"
)
for element in last_round["messages"]: for element in last_round["messages"]:
if element["type"] == "view": if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]); view_data: dict = json.loads(element["data"]["content"])
charts: List = view_data.get("charts") charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_title, charts))[0] find_chart = list(
filter(lambda x: x["chart_name"] == chart_title, charts)
)[0]
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'], detail: ChartDetail = ChartDetail(
chart_type=find_chart['chart_type'], chart_uid=find_chart["chart_uid"],
chart_desc=find_chart['chart_desc'], chart_type=find_chart["chart_type"],
chart_sql=find_chart['chart_sql'], chart_desc=find_chart["chart_desc"],
chart_sql=find_chart["chart_sql"],
db_name=db_name, db_name=db_name,
chart_name=find_chart['chart_name'], chart_name=find_chart["chart_name"],
chart_value=find_chart['values'], chart_value=find_chart["values"],
table_value=conn.run(find_chart['chart_sql']) table_value=conn.run(find_chart["chart_sql"]),
) )
return Result.succ(detail) return Result.succ(detail)
@ -197,36 +232,39 @@ async def get_editor_chart_info(param: dict = Body()):
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(run_param: dict = Body()): async def editor_chart_run(run_param: dict = Body()):
logger.info(f"editor_chart_run:{run_param}") logger.info(f"editor_chart_run:{run_param}")
db_name = run_param['db_name'] db_name = run_param["db_name"]
sql = run_param['sql'] sql = run_param["sql"]
chart_type = run_param['chart_type'] chart_type = run_param["chart_type"]
if not db_name and not sql: if not db_name and not sql:
return Result.faild("SQL run param error") return Result.faild("SQL run param error")
try: try:
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
colunms, sql_result = db_conn.query_ex(sql) colunms, sql_result = db_conn.query_ex(sql)
field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(colunms, sql_result, sql) field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
colunms, sql_result, sql
)
start_time = time.time() * 1000 start_time = time.time() * 1000
# 计算执行耗时 # 计算执行耗时
end_time = time.time() * 1000 end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(result_info="", sql_run_data: SqlRunData = SqlRunData(
result_info="",
run_cost=(end_time - start_time) / 1000, run_cost=(end_time - start_time) / 1000,
colunms=colunms, colunms=colunms,
values=sql_result values=sql_result,
)
return Result.succ(
ChartRunData(
sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
)
) )
return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values, chart_type = chart_type))
except Exception as e: except Exception as e:
sql_result = SqlRunData(result_info=str(e), sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
run_cost=0, return Result.succ(
colunms=[], ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
values=[]
) )
return Result.succ(ChartRunData(sql_data = sql_result,
chart_values=[],
chart_type = chart_type
))
@router.post("/v1/chart/editor/submit", response_model=Result[bool]) @router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
@ -237,35 +275,53 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name) db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
edit_round = max(history_messages, key=lambda x: x['chat_order']) edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round: if edit_round:
try: try:
for element in edit_round["messages"]: for element in edit_round["messages"]:
if element["type"] == "view": if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]); view_data: dict = json.loads(element["data"]["content"])
charts: List = view_data.get("charts") charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[ find_chart = list(
0] filter(
lambda x: x["chart_name"]
== chart_edit_context.chart_title,
charts,
)
)[0]
if chart_edit_context.new_chart_type: if chart_edit_context.new_chart_type:
find_chart['chart_type'] = chart_edit_context.new_chart_type find_chart["chart_type"] = chart_edit_context.new_chart_type
if chart_edit_context.new_comment: if chart_edit_context.new_comment:
find_chart['chart_desc'] = chart_edit_context.new_comment find_chart["chart_desc"] = chart_edit_context.new_comment
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, (
chart_edit_context.new_sql) field_names,
find_chart['chart_sql'] = chart_edit_context.new_sql chart_values,
find_chart['values'] = [value.dict() for value in chart_values] ) = dashboard_data_loader.get_chart_values_by_conn(
find_chart['column_name'] = field_names db_conn, chart_edit_context.new_sql
)
find_chart["chart_sql"] = chart_edit_context.new_sql
find_chart["values"] = [value.dict() for value in chart_values]
find_chart["column_name"] = field_names
element["data"]["content"] = json.dumps(view_data, ensure_ascii=False) element["data"]["content"] = json.dumps(
view_data, ensure_ascii=False
)
if element["type"] == "ai": if element["type"] == "ai":
ai_resp: dict = json.loads(element["data"]["content"]) ai_resp: dict = json.loads(element["data"]["content"])
edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0] edit_item = list(
filter(
lambda x: x["title"] == chart_edit_context.chart_title,
ai_resp,
)
)[0]
edit_item["sql"] = chart_edit_context.new_sql edit_item["sql"] = chart_edit_context.new_sql
edit_item["showcase"] = chart_edit_context.new_chart_type edit_item["showcase"] = chart_edit_context.new_chart_type
edit_item["thoughts"] = chart_edit_context.new_comment edit_item["thoughts"] = chart_edit_context.new_comment
element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False) element["data"]["content"] = json.dumps(
ai_resp, ensure_ascii=False
)
except Exception as e: except Exception as e:
logger.error(f"edit chart exception!{str(e)}", e) logger.error(f"edit chart exception!{str(e)}", e)
return Result.faild(msg=f"Edit chart exception!{str(e)}") return Result.faild(msg=f"Edit chart exception!{str(e)}")

View File

@ -9,7 +9,7 @@ class DataNode(BaseModel):
type: str = "" type: str = ""
default_value: str = None default_value: str = None
can_null: str = 'YES' can_null: str = "YES"
comment: str = None comment: str = None
children: List = [] children: List = []

View File

@ -10,11 +10,13 @@ class DbField(BaseModel):
default_value: str = "" default_value: str = ""
comment: str = "" comment: str = ""
class DbTable(BaseModel): class DbTable(BaseModel):
table_name: str table_name: str
comment: str comment: str
colunm: List[DbField] colunm: List[DbField]
class ChatDbRounds(BaseModel): class ChatDbRounds(BaseModel):
round: int round: int
db_name: str db_name: str
@ -61,4 +63,3 @@ class ChatSqlEditContext(BaseModel):
new_sql: str new_sql: str
new_speak: str = "" new_speak: str = ""

View File

@ -123,11 +123,7 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\*", "*") ai_response = ai_response.replace("\*", "*")
ai_response = ai_response.replace("\t", "") ai_response = ai_response.replace("\t", "")
ai_response = ( ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
ai_response.strip()
.replace("\\n", " ")
.replace("\n", " ")
)
print("un_stream ai response:", ai_response) print("un_stream ai response:", ai_response)
return ai_response return ai_response
else: else:

View File

@ -12,7 +12,6 @@ class Scene:
is_inner: bool = False, is_inner: bool = False,
show_disable=False, show_disable=False,
prepare_scene_code: str = None, prepare_scene_code: str = None,
): ):
self.code = code self.code = code
self.name = name self.name = name
@ -22,6 +21,7 @@ class Scene:
self.show_disable = show_disable self.show_disable = show_disable
self.prepare_scene_code = prepare_scene_code self.prepare_scene_code = prepare_scene_code
class ChatScene(Enum): class ChatScene(Enum):
ChatWithDbExecute = Scene( ChatWithDbExecute = Scene(
code="chat_with_db_execute", code="chat_with_db_execute",
@ -40,7 +40,7 @@ class ChatScene(Enum):
name="Chat Excel", name="Chat Excel",
describe="Dialogue with your excel, use natural language.", describe="Dialogue with your excel, use natural language.",
param_types=["File Select"], param_types=["File Select"],
prepare_scene_code="excel_learning" prepare_scene_code="excel_learning",
) )
ChatWithDbQA = Scene( ChatWithDbQA = Scene(

View File

@ -60,11 +60,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True arbitrary_types_allowed = True
def __init__( def __init__(
self, self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
chat_mode,
chat_session_id,
current_user_input,
select_param: Any = None
): ):
self.chat_session_id = chat_session_id self.chat_session_id = chat_session_id
self.chat_mode = chat_mode self.chat_mode = chat_mode
@ -72,7 +68,6 @@ class BaseChat(ABC):
self.llm_model = CFG.LLM_MODEL self.llm_model = CFG.LLM_MODEL
self.llm_echo = False self.llm_echo = False
### load prompt template ### load prompt template
# self.prompt_template: PromptTemplate = CFG.prompt_templates[ # self.prompt_template: PromptTemplate = CFG.prompt_templates[
# self.chat_mode.value() # self.chat_mode.value()
@ -182,6 +177,7 @@ class BaseChat(ABC):
return response return response
else: else:
from pilot.server.llmserver import worker from pilot.server.llmserver import worker
return worker.generate_stream_gate(payload) return worker.generate_stream_gate(payload)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
@ -235,7 +231,9 @@ class BaseChat(ABC):
### llm speaker ### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response) speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
)
self.current_message.add_view_message(view_message) self.current_message.add_view_message(view_message)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
@ -253,11 +251,9 @@ class BaseChat(ABC):
else: else:
return self.nostream_call() return self.nostream_call()
def prepare(self): def prepare(self):
pass pass
def generate_llm_text(self) -> str: def generate_llm_text(self) -> str:
warnings.warn("This method is deprecated - please use `generate_llm_messages`.") warnings.warn("This method is deprecated - please use `generate_llm_messages`.")
text = "" text = ""
@ -363,9 +359,7 @@ class BaseChat(ABC):
) )
if len(self.history_message) > self.chat_retention_rounds: if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]["messages"]: for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [ if not first_message["type"] in [ModelMessageRoleType.VIEW]:
ModelMessageRoleType.VIEW
]:
message_type = first_message["type"] message_type = first_message["type"]
message_content = first_message["data"]["content"] message_content = first_message["data"]["content"]
history_text += ( history_text += (
@ -394,7 +388,9 @@ class BaseChat(ABC):
+ self.prompt_template.sep + self.prompt_template.sep
) )
history_messages.append( history_messages.append(
ModelMessage(role=message_type, content=message_content) ModelMessage(
role=message_type, content=message_content
)
) )
else: else:

View File

@ -63,6 +63,7 @@ class AIMessage(BaseMessage):
class ViewMessage(BaseMessage): class ViewMessage(BaseMessage):
"""Type of message that is spoken by the AI.""" """Type of message that is spoken by the AI."""
example: bool = False example: bool = False
@property @property
@ -73,6 +74,7 @@ class ViewMessage(BaseMessage):
class SystemMessage(BaseMessage): class SystemMessage(BaseMessage):
"""Type of message that is a system message.""" """Type of message that is a system message."""
@property @property
def type(self) -> str: def type(self) -> str:
"""Type of the message, used for serialization.""" """Type of the message, used for serialization."""

View File

@ -21,7 +21,13 @@ class ChatDashboard(BaseChat):
report_name: str report_name: str
"""Number of results to return from the query""" """Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = "", report_name:str="report"): def __init__(
self,
chat_session_id,
user_input,
select_param: str = "",
report_name: str = "report",
):
""" """ """ """
self.db_name = select_param self.db_name = select_param
super().__init__( super().__init__(
@ -80,7 +86,9 @@ class ChatDashboard(BaseChat):
dashboard_data_loader = DashboardDataLoader() dashboard_data_loader = DashboardDataLoader()
for chart_item in prompt_response: for chart_item in prompt_response:
try: try:
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.database, chart_item.sql) field_names, values = dashboard_data_loader.get_chart_values_by_conn(
self.database, chart_item.sql
)
chart_datas.append( chart_datas.append(
ChartData( ChartData(
chart_uid=str(uuid.uuid1()), chart_uid=str(uuid.uuid1()),
@ -101,5 +109,3 @@ class ChatDashboard(BaseChat):
template_introduce=None, template_introduce=None,
charts=chart_datas, charts=chart_datas,
) )

View File

@ -11,7 +11,6 @@ logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
class DashboardDataLoader: class DashboardDataLoader:
def get_sql_value(self, db_conn, chart_sql: str): def get_sql_value(self, db_conn, chart_sql: str):
return db_conn.query_ex(chart_sql) return db_conn.query_ex(chart_sql)

View File

@ -27,6 +27,7 @@ CFG = Config()
class ChatExcel(BaseChat): class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value() chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1 chat_retention_rounds = 1
def __init__(self, chat_session_id, user_input, select_param: str = ""): def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatExcel chat_mode = ChatScene.ChatExcel
@ -34,9 +35,11 @@ class ChatExcel(BaseChat):
if has_path(select_param): if has_path(select_param):
self.excel_reader = ExcelReader(select_param) self.excel_reader = ExcelReader(select_param)
else: else:
self.excel_reader = ExcelReader(os.path.join( self.excel_reader = ExcelReader(
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
)) )
)
super().__init__( super().__init__(
chat_mode=chat_mode, chat_mode=chat_mode,
@ -70,10 +73,7 @@ class ChatExcel(BaseChat):
] ]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self): def generate_input_values(self):
input_values = { input_values = {
"user_input": self.current_user_input, "user_input": self.current_user_input,
"table_name": self.excel_reader.table_name, "table_name": self.excel_reader.table_name,
@ -90,20 +90,18 @@ class ChatExcel(BaseChat):
"user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis", "user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis",
"parent_mode": self.chat_mode, "parent_mode": self.chat_mode,
"select_param": self.excel_reader.excel_file_name, "select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader "excel_reader": self.excel_reader,
} }
learn_chat = ExcelLearning(**chat_param) learn_chat = ExcelLearning(**chat_param)
result = learn_chat.nostream_call() result = learn_chat.nostream_call()
return result return result
def do_action(self, prompt_response): def do_action(self, prompt_response):
print(f"do_action:{prompt_response}") print(f"do_action:{prompt_response}")
# colunms, datas = self.excel_reader.run(prompt_response.sql) # colunms, datas = self.excel_reader.run(prompt_response.sql)
param = { param = {
"speak": prompt_response.thoughts, "speak": prompt_response.thoughts,
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql) "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
} }
return CFG.command_disply.call(prompt_response.display, **param) return CFG.command_disply.call(prompt_response.display, **param)

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ChatExcelOutputParser from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@ -41,7 +43,7 @@ SQL中需要使用的表名是: {table_name}
RESPONSE_FORMAT_SIMPLE = { RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL", "sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis", "thoughts": "Current thinking and value of data analysis",
"display": "display type name" "display": "display type name",
} }
_DEFAULT_TEMPLATE = ( _DEFAULT_TEMPLATE = (
@ -72,4 +74,3 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -23,7 +23,14 @@ CFG = Config()
class ExcelLearning(BaseChat): class ExcelLearning(BaseChat):
chat_scene: str = ChatScene.ExcelLearning.value() chat_scene: str = ChatScene.ExcelLearning.value()
def __init__(self, chat_session_id, user_input, parent_mode: Any=None, select_param:str=None, excel_reader:Any=None): def __init__(
self,
chat_session_id,
user_input,
parent_mode: Any = None,
select_param: str = None,
excel_reader: Any = None,
):
chat_mode = ChatScene.ExcelLearning chat_mode = ChatScene.ExcelLearning
""" """ """ """
self.excel_file_path = select_param self.excel_file_path = select_param
@ -38,13 +45,12 @@ class ExcelLearning(BaseChat):
self.current_message.chat_mode = parent_mode.value() self.current_message.chat_mode = parent_mode.value()
def generate_input_values(self): def generate_input_values(self):
colunms, datas = self.excel_reader.get_sample_data() colunms, datas = self.excel_reader.get_sample_data()
datas.insert(0, colunms) datas.insert(0, colunms)
input_values = { input_values = {
"data_example": json.dumps(self.excel_reader.get_sample_data(), cls=DateTimeEncoder), "data_example": json.dumps(
self.excel_reader.get_sample_data(), cls=DateTimeEncoder
),
} }
return input_values return input_values

View File

@ -37,8 +37,6 @@ class LearningExcelOutputParser(BaseOutputParser):
plans = response[key] plans = response[key]
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans) return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
def parse_view_response(self, speak, data) -> str: def parse_view_response(self, speak, data) -> str:
### tool out data to table view ### tool out data to table view
html_title = f"### **数据简介**\n{data.desciption} " html_title = f"### **数据简介**\n{data.desciption} "
@ -48,7 +46,9 @@ class LearningExcelOutputParser(BaseOutputParser):
column_index += 1 column_index += 1
keys = item.keys() keys = item.keys()
for key in keys: for key in keys:
html_colunms = html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" html_colunms = (
html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
)
html_plans = f"### **分析计划**\n" html_plans = f"### **分析计划**\n"
index = 0 index = 0

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.base import ChatScene from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import LearningExcelOutputParser from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle from pilot.common.schema import SeparatorStyle
CFG = Config() CFG = Config()
@ -63,5 +65,3 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE, temperature=PROMPT_TEMPERATURE,
) )
CFG.prompt_template_registry.register(prompt, is_default=True) CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -18,22 +18,35 @@ if __name__ == "__main__":
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
# colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) colunms, datas = excel_reader.run(
df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;") """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """
)
df = excel_reader.get_df_by_sql_ex(
"SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;"
)
columns = df.columns.tolist() columns = df.columns.tolist()
plt.rcParams["font.family"] = ["sans-serif"] plt.rcParams["font.family"] = ["sans-serif"]
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) sns.set_style(rc={"font.sans-serif": "Microsoft Yahei"})
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
sns.set_palette("Set3") # 设置颜色主题 sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100) fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
plt.subplots_adjust(top=0.9) plt.subplots_adjust(top=0.9)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小 # 手动设置 labels 的位置和大小
ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10) ax.legend(
plt.axis('equal') # 使饼图为正圆形 loc="center left", bbox_to_anchor=(-1, 0.5, 0, 0), labels=None, fontsize=10
)
plt.axis("equal") # 使饼图为正圆形
plt.show() plt.show()
# #
# #

View File

@ -7,11 +7,13 @@ import numpy as np
from pilot.common.pd_utils import csv_colunm_foramt from pilot.common.pd_utils import csv_colunm_foramt
def excel_colunm_format(old_name: str) -> str: def excel_colunm_format(old_name: str) -> str:
new_column = old_name.strip() new_column = old_name.strip()
new_column = new_column.replace(" ", "_") new_column = new_column.replace(" ", "_")
return new_column return new_column
def add_quotes(sql, column_names=[]): def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "") sql = sql.replace("`", "")
parsed = sqlparse.parse(sql) parsed = sqlparse.parse(sql)
@ -20,6 +22,7 @@ def add_quotes(sql, column_names=[]):
deep_quotes(token, column_names) deep_quotes(token, column_names)
return str(parsed[0]) return str(parsed[0])
def deep_quotes(token, column_names=[]): def deep_quotes(token, column_names=[]):
if hasattr(token, "tokens"): if hasattr(token, "tokens"):
for token_child in token.tokens: for token_child in token.tokens:
@ -32,32 +35,38 @@ def deep_quotes(token, column_names=[]):
else: else:
token.value = f'"{token.value.replace("`", "")}"' token.value = f'"{token.value.replace("`", "")}"'
def is_chinese(string): def is_chinese(string):
# 使用正则表达式匹配中文字符 # 使用正则表达式匹配中文字符
pattern = re.compile(r'[一-龥]') pattern = re.compile(r"[一-龥]")
match = re.search(pattern, string) match = re.search(pattern, string)
return match is not None return match is not None
class ExcelReader: class ExcelReader:
def __init__(self, file_path): def __init__(self, file_path):
file_name = os.path.basename(file_path) file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0] file_name_without_extension = os.path.splitext(file_name)[0]
self.excel_file_name = file_name self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1] self.extension = os.path.splitext(file_name)[1]
# read excel file # read excel file
if file_path.endswith('.xlsx') or file_path.endswith('.xls'): if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
df_tmp = pd.read_excel(file_path) df_tmp = pd.read_excel(file_path)
self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) self.df = pd.read_excel(
elif file_path.endswith('.csv'): file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
elif file_path.endswith(".csv"):
df_tmp = pd.read_csv(file_path) df_tmp = pd.read_csv(file_path)
self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) self.df = pd.read_csv(
file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else: else:
raise ValueError("Unsupported file format.") raise ValueError("Unsupported file format.")
self.df.replace('', np.nan, inplace=True) self.df.replace("", np.nan, inplace=True)
self.columns_map = {} self.columns_map = {}
for column_name in df_tmp.columns: for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)}) self.columns_map.update({column_name: excel_colunm_format(column_name)})
@ -66,11 +75,10 @@ class ExcelReader:
except Exception as e: except Exception as e:
print("transfor column error" + column_name) print("transfor column error" + column_name)
self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_')) self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))
# connect DuckDB # connect DuckDB
self.db = duckdb.connect(database=':memory:', read_only=False) self.db = duckdb.connect(database=":memory:", read_only=False)
self.table_name = file_name_without_extension self.table_name = file_name_without_extension
# write data in duckdb # write data in duckdb
@ -92,5 +100,4 @@ class ExcelReader:
return pd.DataFrame(values, columns=colunms) return pd.DataFrame(values, columns=colunms)
def get_sample_data(self): def get_sample_data(self):
return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;') return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;")

View File

@ -47,7 +47,9 @@ class ChatWithDbAutoExecute(BaseChat):
client = DBSummaryClient() client = DBSummaryClient()
try: try:
table_infos = client.get_db_summary( table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE dbname=self.db_name,
query=self.current_user_input,
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
) )
except Exception as e: except Exception as e:
print("db summary find error!" + str(e)) print("db summary find error!" + str(e))

View File

@ -8,6 +8,7 @@ from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader from pilot.scene.chat_db.data_loader import DbDataLoader
CFG = Config() CFG = Config()
@ -50,5 +51,3 @@ class DbChatOutputParser(BaseOutputParser):
return view_text return view_text
else: else:
return data_loader.get_table_view_by_conn(data, speak) return data_loader.get_table_view_by_conn(data, speak)

View File

@ -1,8 +1,7 @@
import pandas as pd import pandas as pd
class DbDataLoader: class DbDataLoader:
def get_table_view_by_conn(self, data, speak): def get_table_view_by_conn(self, data, speak):
### tool out data to table view ### tool out data to table view
if len(data) <= 1: if len(data) <= 1:

View File

@ -20,12 +20,7 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator: PluginPromptGenerator plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None select_plugin: str = None
def __init__( def __init__(self, chat_session_id, user_input, select_param: str = None):
self,
chat_session_id,
user_input,
select_param: str = None
):
self.plugin_selector = select_param self.plugin_selector = select_param
super().__init__( super().__init__(
chat_mode=ChatScene.ChatExecution, chat_mode=ChatScene.ChatExecution,

View File

@ -30,7 +30,6 @@ class ChatNormal(BaseChat):
input_values = {"input": self.current_user_input} input_values = {"input": self.current_user_input}
return input_values return input_values
@property @property
def chat_type(self) -> str: def chat_type(self) -> str:
return ChatScene.ChatNormal.value return ChatScene.ChatNormal.value

View File

@ -117,12 +117,10 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"tokens": once.tokens if once.tokens else 0, "tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages), "messages": messages_to_dict(once.messages),
"param_type": once.param_type, "param_type": once.param_type,
"param_value": once.param_value "param_value": once.param_value,
} }
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
return [_conversation_to_dic(m) for m in conversations] return [_conversation_to_dic(m) for m in conversations]

View File

@ -64,7 +64,6 @@ def server_init(args):
cfg.command_registry = command_registry cfg.command_registry = command_registry
command_disply_commands = [ command_disply_commands = [
"pilot.commands.disply_type.show_chart_gen", "pilot.commands.disply_type.show_chart_gen",
"pilot.commands.disply_type.show_table_gen", "pilot.commands.disply_type.show_table_gen",
@ -74,6 +73,3 @@ def server_init(args):
for command in command_disply_commands: for command in command_disply_commands:
command_disply_registry.import_commands(command) command_disply_registry.import_commands(command)
cfg.command_disply = command_disply_registry cfg.command_disply = command_disply_registry

View File

@ -81,7 +81,9 @@ app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1) # app.include_router(api_editor_route_v1)
os.makedirs(static_message_img_path, exist_ok=True) os.makedirs(static_message_img_path, exist_ok=True)
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2") app.mount(
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
)
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")

View File

@ -19,7 +19,6 @@ from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log") logger = build_logger("db_summary", LOGDIR + "db_summary.log")
CFG = Config() CFG = Config()
chat_factory = ChatFactory() chat_factory = ChatFactory()
@ -145,7 +144,9 @@ class DBSummaryClient:
try: try:
self.db_summary_embedding(item["db_name"], item["db_type"]) self.db_summary_embedding(item["db_name"], item["db_type"])
except Exception as e: except Exception as e:
logger.warn(f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e) logger.warn(
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
)
def init_db_profile(self, db_summary_client, dbname, embeddings): def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = { profile_store_config = {

View File

@ -75,7 +75,6 @@ def build_logger(logger_name, logger_filename):
item.addHandler(handler) item.addHandler(handler)
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
# Get logger # Get logger
logger = logging.getLogger(logger_name) logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)