style:fmt

This commit is contained in:
aries_ckt 2023-08-29 19:57:33 +08:00
parent 0efaffc031
commit 71b9cd14a6
36 changed files with 413 additions and 280 deletions

View File

@ -20,8 +20,17 @@ CFG = Config()
logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log")
static_message_img_path = os.path.join(os.getcwd(), "message/img")
def zh_font_set():
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
@ -29,11 +38,14 @@ def zh_font_set():
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts
plt.rcParams["font.sans-serif"] = can_use_fonts
@command("response_line_chart", "Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"')
@command(
"response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},")
@ -44,7 +56,15 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
# set font
# zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
@ -52,31 +72,34 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7)
sns.set(context='notebook', style='ticks', rc=rc)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.lineplot(df, x=columns[0], y=columns[1], ax=ax)
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"')
@command(
"response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},")
columns = df.columns.tolist()
@ -85,7 +108,15 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
# set font
# zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
@ -93,29 +124,32 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {'font.sans-serif': can_use_fonts}
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=.5, s=.7)
sns.set(context='notebook', style='ticks', rc=rc)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax)
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"')
@command(
"response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},")
columns = df.columns.tolist()
@ -123,7 +157,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
raise ValueError("No Data")
# set font
# zh_font_set()
font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi']
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
@ -131,23 +173,35 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams['font.sans-serif'] = can_use_fonts
plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题
plt.rcParams["font.sans-serif"] = can_use_fonts
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小
ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10)
ax.legend(
loc="upper right",
bbox_to_anchor=(0, 0, 1, 1),
labels=df[columns[0]].values,
fontsize=10,
)
plt.axis('equal') # 使饼图为正圆形
plt.axis("equal") # 使饼图为正圆形
# plt.title(columns[0])
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches='tight', dpi=100)
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

View File

@ -11,8 +11,12 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_table", "Table display, suitable for display with many display columns or non-numeric columns", '"speak": "<speak>", "df":"<data frame>"')
def response_table(speak: str, df: DataFrame) -> str:
@command(
"response_table",
"Table display, suitable for display with many display columns or non-numeric columns",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_table(speak: str, df: DataFrame) -> str:
logger.info(f"response_table:{speak}")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())

View File

@ -10,9 +10,12 @@ CFG = Config()
logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command("response_data_text", "Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"')
def response_data_text(speak: str, df: DataFrame) -> str:
@command(
"response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}")
data = df.values

View File

@ -1,6 +1,6 @@
def csv_colunm_foramt(val):
    """Convert a currency-formatted cell value to a float.

    A value whose text contains ``$`` or ``¥`` has that symbol and every
    thousands separator (``,``) removed and is then parsed with ``float``.
    Any other value is returned unchanged.

    Args:
        val: The raw cell value; may be a string or any other object.

    Returns:
        A ``float`` for currency-looking values, otherwise ``val`` itself.

    Raises:
        ValueError: If the text left after stripping is not a valid float.
    """
    # Work on the string form throughout: the symbol check already used
    # ``str(val)``, but the stripping was done on ``val`` directly, which
    # fails for non-string values.
    text = str(val)
    for symbol in ("$", "¥"):
        if symbol in text:
            return float(text.replace(symbol, "").replace(",", ""))
    return val

View File

@ -262,7 +262,7 @@ class RDBMSDatabase(BaseConnect):
"""Format the error message"""
return f"Error: {e}"
def __write(self, write_sql):
def __write(self, write_sql):
print(f"Write[{write_sql}]")
db_cache = self._engine.url.database
result = self.session.execute(text(write_sql))
@ -272,7 +272,7 @@ class RDBMSDatabase(BaseConnect):
print(f"SQL[{write_sql}], result:{result.rowcount}")
return result.rowcount
def __query(self,query, fetch: str = "all"):
def __query(self, query, fetch: str = "all"):
"""
only for query
Args:
@ -325,6 +325,7 @@ class RDBMSDatabase(BaseConnect):
result = list(result)
return field_names, result
return []
def run(self, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results."""
print("SQL:" + command)
@ -333,12 +334,12 @@ class RDBMSDatabase(BaseConnect):
parsed, ttype, sql_type, table_name = self.__sql_parse(command)
if ttype == sqlparse.tokens.DML:
if sql_type == "SELECT":
return self.__query( command, fetch)
return self.__query(command, fetch)
else:
self.__write( command)
self.__write(command)
select_sql = self.convert_sql_write_to_select(command)
print(f"write result query:{select_sql}")
return self.__query( select_sql)
return self.__query(select_sql)
else:
print(f"DDL execution determines whether to enable through configuration ")
@ -351,10 +352,10 @@ class RDBMSDatabase(BaseConnect):
result.insert(0, field_names)
print("DDL Result:" + str(result))
if not result:
return self.__query( f"SHOW COLUMNS FROM {table_name}")
return self.__query(f"SHOW COLUMNS FROM {table_name}")
return result
else:
return self.__query( f"SHOW COLUMNS FROM {table_name}")
return self.__query(f"SHOW COLUMNS FROM {table_name}")
def run_no_throw(self, session, command: str, fetch: str = "all") -> List:
"""Execute a SQL command and return a string representing the results.

View File

@ -2,6 +2,6 @@ from pilot.configs.config import Config
from pilot.connections.manages.connection_manager import ConnectManager
if __name__ == "__main__":
    # Manual check: print whatever the connection manager reports from
    # get_all_completed_types().
    manager = ConnectManager()
    completed_types = manager.get_all_completed_types()
    print(str(completed_types))

View File

@ -94,8 +94,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
cursor.commit()
self.connect.commit()
def update(self, messages:List[OnceConversation]) -> None:
def update(self, messages: List[OnceConversation]) -> None:
cursor = self.connect.cursor()
cursor.execute(
"UPDATE chat_history set messages=? where conv_uid=?",
@ -161,7 +160,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory):
return {}
def get_messages(self) -> List[OnceConversation]:
cursor = self.connect.cursor()
cursor.execute(

View File

@ -110,7 +110,6 @@ async def db_connect_delete(db_name: str = None):
@router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo])
async def db_support_types():
support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types()
db_type_infos = []
for type in support_types:
@ -130,7 +129,7 @@ async def dialogue_list(user_id: str = None):
chat_mode = item.get("chat_mode")
messages = json.loads(item.get("messages"))
last_round = max(messages, key=lambda x: x['chat_order'])
last_round = max(messages, key=lambda x: x["chat_order"])
if "param_value" in last_round:
select_param = last_round["param_value"]
else:
@ -139,7 +138,7 @@ async def dialogue_list(user_id: str = None):
conv_uid=conv_uid,
user_input=summary,
chat_mode=chat_mode,
select_param=select_param
select_param=select_param,
)
dialogues.append(conv_vo)
@ -213,7 +212,9 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File
),
)
## chat prepare
dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename)
dialogue = ConversationVo(
conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename
)
chat: BaseChat = get_chat_instance(dialogue)
resp = chat.prepare()
@ -229,7 +230,8 @@ async def dialogue_delete(con_uid: str):
history_mem.delete()
return Result.succ(None)
def get_hist_messages(conv_uid:str):
def get_hist_messages(conv_uid: str):
message_vos: List[MessageVo] = []
history_mem = DuckdbHistoryMemory(conv_uid)
history_messages: List[OnceConversation] = history_mem.get_messages()
@ -264,7 +266,7 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat:
chat_param = {
"chat_session_id": dialogue.conv_uid,
"user_input": dialogue.user_input,
"select_param": dialogue.select_param
"select_param": dialogue.select_param,
}
chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param)
return chat

View File

@ -22,7 +22,7 @@ from pilot.openapi.editor_view_model import (
ChartDetail,
ChatChartEditContext,
ChatSqlEditContext,
DbTable
DbTable,
)
from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData
@ -38,7 +38,9 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log")
@router.get("/v1/editor/db/tables", response_model=Result[DbTable])
async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""):
async def get_editor_tables(
db_name: str, page_index: int, page_size: int, search_str: str = ""
):
logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}")
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
tables = db_conn.get_table_names()
@ -49,8 +51,15 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc
fields = db_conn.get_fields(table)
for field in fields:
table_node.children.append(
DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3],
comment=field[-1]))
DataNode(
title=field[0],
key=field[0],
type=field[1],
default_value=field[2],
can_null=field[3],
comment=field[-1],
)
)
return Result.succ(db_node)
@ -68,8 +77,11 @@ async def get_editor_sql_rounds(con_uid: str):
if element["type"] == "human":
round_name = element["data"]["content"]
if once.get("param_value"):
round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"],
round_name=round_name)
round: ChatDbRounds = ChatDbRounds(
round=once["chat_order"],
db_name=once["param_value"],
round_name=round_name,
)
result.append(round)
return Result.succ(result)
@ -84,8 +96,14 @@ async def get_editor_sql(con_uid: str, round: int):
if int(once["chat_order"]) == round:
for element in once["messages"]:
if element["type"] == "ai":
logger.info(f'history ai json resp:{element["data"]["content"]}')
context = element["data"]["content"].replace("\\n", " ").replace("\n", " ")
logger.info(
f'history ai json resp:{element["data"]["content"]}'
)
context = (
element["data"]["content"]
.replace("\\n", " ")
.replace("\n", " ")
)
return Result.succ(json.loads(context))
return Result.faild(msg="not have sql!")
@ -93,8 +111,8 @@ async def get_editor_sql(con_uid: str, round: int):
@router.post("/v1/editor/sql/run", response_model=Result[SqlRunData])
async def editor_sql_run(run_param: dict = Body()):
logger.info(f"editor_sql_run:{run_param}")
db_name = run_param['db_name']
sql = run_param['sql']
db_name = run_param["db_name"]
sql = run_param["sql"]
if not db_name and not sql:
return Result.faild("SQL run param error")
conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
@ -104,18 +122,17 @@ async def editor_sql_run(run_param: dict = Body()):
colunms, sql_result = conn.query_ex(sql)
# 计算执行耗时
end_time = time.time() * 1000
sql_run_data: SqlRunData = SqlRunData(result_info="",
run_cost=(end_time - start_time) / 1000,
colunms=colunms,
values=sql_result
)
sql_run_data: SqlRunData = SqlRunData(
result_info="",
run_cost=(end_time - start_time) / 1000,
colunms=colunms,
values=sql_result,
)
return Result.succ(sql_run_data)
except Exception as e:
return Result.succ(SqlRunData(result_info=str(e),
run_cost=0,
colunms=[],
values=[]
))
return Result.succ(
SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
)
@router.post("/v1/sql/editor/submit")
@ -126,18 +143,24 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
if history_messages:
conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name)
edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0]
edit_round = list(
filter(
lambda x: x["chat_order"] == sql_edit_context.conv_round,
history_messages,
)
)[0]
if edit_round:
for element in edit_round["messages"]:
if element["type"] == "ai":
db_resp = json.loads(element["data"]["content"])
db_resp['thoughts'] = sql_edit_context.new_speak
db_resp['sql'] = sql_edit_context.new_sql
db_resp["thoughts"] = sql_edit_context.new_speak
db_resp["sql"] = sql_edit_context.new_sql
element["data"]["content"] = json.dumps(db_resp)
if element["type"] == "view":
data_loader = DbDataLoader()
element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql),
sql_edit_context.new_speak)
element["data"]["content"] = data_loader.get_table_view_by_conn(
conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak
)
history_mem.update(history_messages)
return Result.succ(None)
return Result.faild(msg="Edit Faild!")
@ -145,16 +168,21 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()):
@router.get("/v1/editor/chart/list", response_model=Result[ChartList])
async def get_editor_chart_list(con_uid: str):
    """Return the chart list stored in the latest round of a conversation.

    Loads the stored history for ``con_uid``, picks the round with the
    highest ``chat_order`` and parses the content of its first ``ai``
    message as the chart definitions.

    Args:
        con_uid: Conversation id used to open the history store.

    Returns:
        ``Result.succ`` wrapping a ``ChartList`` when an ``ai`` message is
        found, otherwise ``Result.faild`` with "Not have charts!".
    """
    # Log under this endpoint's own name; the message previously carried the
    # label of get_editor_sql_rounds, which made the log misleading.
    logger.info(f"get_editor_chart_list:{con_uid}")
    history_mem = DuckdbHistoryMemory(con_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        last_round = max(history_messages, key=lambda x: x["chat_order"])
        db_name = last_round["param_value"]
        for element in last_round["messages"]:
            if element["type"] == "ai":
                chart_list: ChartList = ChartList(
                    round=last_round["chat_order"],
                    db_name=db_name,
                    charts=json.loads(element["data"]["content"]),
                )
                return Result.succ(chart_list)
    return Result.faild(msg="Not have charts!")
@ -162,33 +190,40 @@ async def get_editor_chart_list(con_uid: str):
@router.post("/v1/editor/chart/info", response_model=Result[ChartDetail])
async def get_editor_chart_info(param: dict = Body()):
    """Return the detail of one chart from the latest round of a conversation.

    Args:
        param: Request body; must contain ``con_uid`` (conversation id) and
            ``chart_title`` (the ``chart_name`` of the chart to look up).

    Returns:
        ``Result.succ`` wrapping a ``ChartDetail`` (including the rows from
        re-running the chart's SQL) when the chart is found.
        ``Result.faild`` when the round has no database name recorded, or
        when no ``view`` message contains a chart with the requested title.
    """
    logger.info(f"get_editor_chart_info:{param}")
    conv_uid = param["con_uid"]
    chart_title = param["chart_title"]

    history_mem = DuckdbHistoryMemory(conv_uid)
    history_messages: List[OnceConversation] = history_mem.get_messages()
    if history_messages:
        last_round = max(history_messages, key=lambda x: x["chat_order"])
        db_name = last_round["param_value"]
        if not db_name:
            logger.error(
                "this dashboard dialogue version too old, can't support editor!"
            )
            return Result.faild(
                msg="this dashboard dialogue version too old, can't support editor!"
            )
        for element in last_round["messages"]:
            if element["type"] != "view":
                continue
            view_data: dict = json.loads(element["data"]["content"])
            # A view payload without "charts" is treated as having none.
            charts: List = view_data.get("charts") or []
            # Look the chart up without indexing into an empty list, so an
            # unknown title falls through to the "can't find" result below
            # instead of raising IndexError.
            find_chart = next(
                (chart for chart in charts if chart["chart_name"] == chart_title),
                None,
            )
            if find_chart is None:
                continue

            conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)
            detail: ChartDetail = ChartDetail(
                chart_uid=find_chart["chart_uid"],
                chart_type=find_chart["chart_type"],
                chart_desc=find_chart["chart_desc"],
                chart_sql=find_chart["chart_sql"],
                db_name=db_name,
                chart_name=find_chart["chart_name"],
                chart_value=find_chart["values"],
                table_value=conn.run(find_chart["chart_sql"]),
            )
            return Result.succ(detail)
    return Result.faild(msg="Can't Find Chart Detail Info!")
@ -197,36 +232,39 @@ async def get_editor_chart_info(param: dict = Body()):
@router.post("/v1/editor/chart/run", response_model=Result[ChartRunData])
async def editor_chart_run(run_param: dict = Body()):
    """Run a chart's SQL and return the rows plus the derived chart values.

    Args:
        run_param: Request body; must contain ``db_name``, ``sql`` and
            ``chart_type``.

    Returns:
        ``Result.succ`` wrapping a ``ChartRunData``. On success it carries the
        query columns/rows, the query duration in seconds and the chart
        values. If anything raises while querying or building chart values,
        the error text is placed in ``result_info`` with empty data.
        ``Result.faild`` when both ``db_name`` and ``sql`` are empty.
    """
    logger.info(f"editor_chart_run:{run_param}")
    db_name = run_param["db_name"]
    sql = run_param["sql"]
    chart_type = run_param["chart_type"]
    # NOTE(review): this only rejects the request when BOTH values are empty;
    # `or` looks like the intent, but the SQL-run endpoint uses the same
    # check, so it is left unchanged here -- confirm and fix both together.
    if not db_name and not sql:
        return Result.faild("SQL run param error")
    try:
        dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
        db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name)

        # Time the query itself. The start timestamp used to be taken after
        # the query had already finished, so run_cost was always ~0.
        start_time = time.time() * 1000
        colunms, sql_result = db_conn.query_ex(sql)
        end_time = time.time() * 1000

        _field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(
            colunms, sql_result, sql
        )

        sql_run_data: SqlRunData = SqlRunData(
            result_info="",
            # Timestamps are in milliseconds; report the cost in seconds.
            run_cost=(end_time - start_time) / 1000,
            colunms=colunms,
            values=sql_result,
        )
        return Result.succ(
            ChartRunData(
                sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type
            )
        )
    except Exception as e:
        # Failures are reported inside a successful envelope so the editor
        # can display the error text next to an empty result set.
        sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[])
        return Result.succ(
            ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type)
        )
@router.post("/v1/chart/editor/submit", response_model=Result[bool])
async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()):
@ -237,35 +275,53 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body())
dashboard_data_loader: DashboardDataLoader = DashboardDataLoader()
db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name)
edit_round = max(history_messages, key=lambda x: x['chat_order'])
edit_round = max(history_messages, key=lambda x: x["chat_order"])
if edit_round:
try:
for element in edit_round["messages"]:
if element["type"] == "view":
view_data: dict = json.loads(element["data"]["content"]);
view_data: dict = json.loads(element["data"]["content"])
charts: List = view_data.get("charts")
find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[
0]
find_chart = list(
filter(
lambda x: x["chart_name"]
== chart_edit_context.chart_title,
charts,
)
)[0]
if chart_edit_context.new_chart_type:
find_chart['chart_type'] = chart_edit_context.new_chart_type
find_chart["chart_type"] = chart_edit_context.new_chart_type
if chart_edit_context.new_comment:
find_chart['chart_desc'] = chart_edit_context.new_comment
find_chart["chart_desc"] = chart_edit_context.new_comment
field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn,
chart_edit_context.new_sql)
find_chart['chart_sql'] = chart_edit_context.new_sql
find_chart['values'] = [value.dict() for value in chart_values]
find_chart['column_name'] = field_names
(
field_names,
chart_values,
) = dashboard_data_loader.get_chart_values_by_conn(
db_conn, chart_edit_context.new_sql
)
find_chart["chart_sql"] = chart_edit_context.new_sql
find_chart["values"] = [value.dict() for value in chart_values]
find_chart["column_name"] = field_names
element["data"]["content"] = json.dumps(view_data, ensure_ascii=False)
element["data"]["content"] = json.dumps(
view_data, ensure_ascii=False
)
if element["type"] == "ai":
ai_resp: dict = json.loads(element["data"]["content"])
edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0]
edit_item = list(
filter(
lambda x: x["title"] == chart_edit_context.chart_title,
ai_resp,
)
)[0]
edit_item["sql"] = chart_edit_context.new_sql
edit_item["showcase"] = chart_edit_context.new_chart_type
edit_item["thoughts"] = chart_edit_context.new_comment
element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False)
element["data"]["content"] = json.dumps(
ai_resp, ensure_ascii=False
)
except Exception as e:
logger.error(f"edit chart exception!{str(e)}", e)
return Result.faild(msg=f"Edit chart exception!{str(e)}")

View File

@ -9,7 +9,7 @@ class DataNode(BaseModel):
type: str = ""
default_value: str = None
can_null: str = 'YES'
can_null: str = "YES"
comment: str = None
children: List = []

View File

@ -10,11 +10,13 @@ class DbField(BaseModel):
default_value: str = ""
comment: str = ""
class DbTable(BaseModel):
    """Description of one database table: its name, comment and fields."""

    # Name of the table.
    table_name: str
    # Comment text attached to the table.
    comment: str
    # Field descriptors of the table (attribute name spelling kept as-is,
    # since it is part of the model's schema).
    colunm: List[DbField]
class ChatDbRounds(BaseModel):
round: int
db_name: str
@ -61,4 +63,3 @@ class ChatSqlEditContext(BaseModel):
new_sql: str
new_speak: str = ""

View File

@ -123,11 +123,7 @@ class BaseOutputParser(ABC):
ai_response = ai_response.replace("\*", "*")
ai_response = ai_response.replace("\t", "")
ai_response = (
ai_response.strip()
.replace("\\n", " ")
.replace("\n", " ")
)
ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
print("un_stream ai response:", ai_response)
return ai_response
else:
@ -209,9 +205,9 @@ class BaseOutputParser(ABC):
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = (
cleaned_output.strip()
.replace("\\n", " ")
.replace("\n", " ")
.replace("\\", " ")
.replace("\\n", " ")
.replace("\n", " ")
.replace("\\", " ")
)
cleaned_output = self.__illegal_json_ends(cleaned_output)
return cleaned_output

View File

@ -12,7 +12,6 @@ class Scene:
is_inner: bool = False,
show_disable=False,
prepare_scene_code: str = None,
):
self.code = code
self.name = name
@ -22,38 +21,39 @@ class Scene:
self.show_disable = show_disable
self.prepare_scene_code = prepare_scene_code
class ChatScene(Enum):
ChatWithDbExecute = Scene(
code = "chat_with_db_execute",
name = "Chat Data",
describe = "Dialogue with your private data through natural language.",
param_types = ["DB Select"],
code="chat_with_db_execute",
name="Chat Data",
describe="Dialogue with your private data through natural language.",
param_types=["DB Select"],
)
ExcelLearning = Scene(
code = "excel_learning",
name = "Excel Learning",
describe = "Analyze and summarize your excel files.",
is_inner = True,
code="excel_learning",
name="Excel Learning",
describe="Analyze and summarize your excel files.",
is_inner=True,
)
ChatExcel = Scene(
code = "chat_excel",
name = "Chat Excel",
describe = "Dialogue with your excel, use natural language.",
code="chat_excel",
name="Chat Excel",
describe="Dialogue with your excel, use natural language.",
param_types=["File Select"],
prepare_scene_code="excel_learning"
prepare_scene_code="excel_learning",
)
ChatWithDbQA = Scene(
code = "chat_with_db_qa",
name = "Chat DB",
describe = "Have a Professional Conversation with Metadata.",
param_types = ["DB Select"],
code="chat_with_db_qa",
name="Chat DB",
describe="Have a Professional Conversation with Metadata.",
param_types=["DB Select"],
)
ChatExecution = Scene(
code = "chat_execution",
name = "Use Plugin",
describe = "Use tools through dialogue to accomplish your goals.",
param_types = ["Plugin Select"],
code="chat_execution",
name="Use Plugin",
describe="Use tools through dialogue to accomplish your goals.",
param_types=["Plugin Select"],
)
InnerChatDBSummary = Scene(
@ -78,7 +78,7 @@ class ChatScene(Enum):
@staticmethod
def of_mode(mode):
return [x for x in ChatScene._value_ if x.code == mode][0]
return [x for x in ChatScene._value_ if x.code == mode][0]
@staticmethod
def is_valid_mode(mode):
@ -100,4 +100,4 @@ class ChatScene(Enum):
return self._value_.show_disable
def is_inner(self):
return self._value_.is_inner
return self._value_.is_inner

View File

@ -60,11 +60,7 @@ class BaseChat(ABC):
arbitrary_types_allowed = True
def __init__(
self,
chat_mode,
chat_session_id,
current_user_input,
select_param: Any = None
self, chat_mode, chat_session_id, current_user_input, select_param: Any = None
):
self.chat_session_id = chat_session_id
self.chat_mode = chat_mode
@ -72,7 +68,6 @@ class BaseChat(ABC):
self.llm_model = CFG.LLM_MODEL
self.llm_echo = False
### load prompt template
# self.prompt_template: PromptTemplate = CFG.prompt_templates[
# self.chat_mode.value()
@ -182,6 +177,7 @@ class BaseChat(ABC):
return response
else:
from pilot.server.llmserver import worker
return worker.generate_stream_gate(payload)
except Exception as e:
print(traceback.format_exc())
@ -235,7 +231,9 @@ class BaseChat(ABC):
### llm speaker
speak_to_user = self.get_llm_speak(prompt_define_response)
view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result)
view_message = self.prompt_template.output_parser.parse_view_response(
speak_to_user, result
)
self.current_message.add_view_message(view_message)
except Exception as e:
print(traceback.format_exc())
@ -253,10 +251,8 @@ class BaseChat(ABC):
else:
return self.nostream_call()
def prepare(self):
    """Do nothing and return ``None`` in this base implementation."""
def generate_llm_text(self) -> str:
warnings.warn("This method is deprecated - please use `generate_llm_messages`.")
@ -363,9 +359,7 @@ class BaseChat(ABC):
)
if len(self.history_message) > self.chat_retention_rounds:
for first_message in self.history_message[0]["messages"]:
if not first_message["type"] in [
ModelMessageRoleType.VIEW
]:
if not first_message["type"] in [ModelMessageRoleType.VIEW]:
message_type = first_message["type"]
message_content = first_message["data"]["content"]
history_text += (
@ -394,7 +388,9 @@ class BaseChat(ABC):
+ self.prompt_template.sep
)
history_messages.append(
ModelMessage(role=message_type, content=message_content)
ModelMessage(
role=message_type, content=message_content
)
)
else:

View File

@ -63,6 +63,7 @@ class AIMessage(BaseMessage):
class ViewMessage(BaseMessage):
"""Type of message that is spoken by the AI."""
example: bool = False
@property
@ -73,6 +74,7 @@ class ViewMessage(BaseMessage):
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
@property
def type(self) -> str:
"""Type of the message, used for serialization."""

View File

@ -21,9 +21,15 @@ class ChatDashboard(BaseChat):
report_name: str
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = "", report_name:str="report"):
def __init__(
self,
chat_session_id,
user_input,
select_param: str = "",
report_name: str = "report",
):
""" """
self.db_name=select_param
self.db_name = select_param
super().__init__(
chat_mode=ChatScene.ChatDashboard,
chat_session_id=chat_session_id,
@ -80,7 +86,9 @@ class ChatDashboard(BaseChat):
dashboard_data_loader = DashboardDataLoader()
for chart_item in prompt_response:
try:
field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.database, chart_item.sql)
field_names, values = dashboard_data_loader.get_chart_values_by_conn(
self.database, chart_item.sql
)
chart_datas.append(
ChartData(
chart_uid=str(uuid.uuid1()),
@ -101,5 +109,3 @@ class ChatDashboard(BaseChat):
template_introduce=None,
charts=chart_datas,
)

View File

@ -11,15 +11,14 @@ logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log")
class DashboardDataLoader:
def get_sql_value(self, db_conn, chart_sql: str):
    """Run *chart_sql* through ``db_conn.query_ex`` and return its result as is.

    Args:
        db_conn: Connection object exposing ``query_ex(sql)``.
        chart_sql: SQL text to execute.
    """
    query_result = db_conn.query_ex(chart_sql)
    return query_result
def get_chart_values_by_conn(self, db_conn, chart_sql: str):
    """Execute *chart_sql* on *db_conn* and convert the result to chart values.

    Args:
        db_conn: Connection object exposing ``query_ex(sql)``, which must
            return a ``(field_names, rows)`` pair.
        chart_sql: SQL text to execute.

    Returns:
        Whatever ``get_chart_values_by_data`` returns for the query result.
    """
    columns, rows = db_conn.query_ex(chart_sql)
    chart_values = self.get_chart_values_by_data(columns, rows, chart_sql)
    return chart_values
def get_chart_values_by_data(self, field_names, datas, chart_sql: str) :
def get_chart_values_by_data(self, field_names, datas, chart_sql: str):
logger.info(f"get_chart_values_by_conn:{chart_sql}")
try:
values: List[ValueItem] = []
@ -57,7 +56,7 @@ class DashboardDataLoader:
logger.debug("Prepare Chart Data Faild!" + str(e))
raise ValueError("Prepare Chart Data Faild!")
def get_chart_values_by_db(self, db_name: str, chart_sql: str):
    """Look up the connection for *db_name* and compute chart values for *chart_sql*.

    The connection is obtained from ``CFG.LOCAL_DB_MANAGE.get_connect`` and the
    work is delegated to ``get_chart_values_by_conn``.
    """
    logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}")
    return self.get_chart_values_by_conn(
        CFG.LOCAL_DB_MANAGE.get_connect(db_name), chart_sql
    )

View File

@ -27,6 +27,7 @@ CFG = Config()
class ChatExcel(BaseChat):
chat_scene: str = ChatScene.ChatExcel.value()
chat_retention_rounds = 1
def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatExcel
@ -34,9 +35,11 @@ class ChatExcel(BaseChat):
if has_path(select_param):
self.excel_reader = ExcelReader(select_param)
else:
self.excel_reader = ExcelReader(os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
))
self.excel_reader = ExcelReader(
os.path.join(
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param
)
)
super().__init__(
chat_mode=chat_mode,
@ -70,12 +73,9 @@ class ChatExcel(BaseChat):
]
return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings))
def generate_input_values(self):
input_values = {
"user_input": self.current_user_input,
"user_input": self.current_user_input,
"table_name": self.excel_reader.table_name,
"disply_type": self._generate_numbered_list(),
}
@ -87,23 +87,21 @@ class ChatExcel(BaseChat):
return None
chat_param = {
"chat_session_id": self.chat_session_id,
"user_input": "[" + self.excel_reader.excel_file_name +"]" + " Analysis",
"user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis",
"parent_mode": self.chat_mode,
"select_param":self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader
"select_param": self.excel_reader.excel_file_name,
"excel_reader": self.excel_reader,
}
learn_chat = ExcelLearning(**chat_param)
result = learn_chat.nostream_call()
return result
def do_action(self, prompt_response):
    """Run the SQL from *prompt_response* and render it with the chosen display command.

    Args:
        prompt_response: Parsed model response; its ``thoughts``, ``sql`` and
            ``display`` attributes are read.

    Returns:
        The result of ``CFG.command_disply.call`` for the display command
        named by ``prompt_response.display``.
    """
    print(f"do_action:{prompt_response}")
    speak_text = prompt_response.thoughts
    result_frame = self.excel_reader.get_df_by_sql_ex(prompt_response.sql)
    return CFG.command_disply.call(
        prompt_response.display, speak=speak_text, df=result_frame
    )

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ChatExcelOutputParser
from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import (
ChatExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle
CFG = Config()
@ -39,9 +41,9 @@ SQL中需要使用的表名是: {table_name}
"""
RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis",
"display": "display type name"
"sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis",
"display": "display type name",
}
_DEFAULT_TEMPLATE = (
@ -67,9 +69,8 @@ prompt = PromptTemplate(
output_parser=ChatExcelOutputParser(
sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT
),
need_historical_messages = True,
need_historical_messages=True,
# example_selector=sql_data_example,
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -23,7 +23,14 @@ CFG = Config()
class ExcelLearning(BaseChat):
chat_scene: str = ChatScene.ExcelLearning.value()
def __init__(self, chat_session_id, user_input, parent_mode: Any=None, select_param:str=None, excel_reader:Any=None):
def __init__(
self,
chat_session_id,
user_input,
parent_mode: Any = None,
select_param: str = None,
excel_reader: Any = None,
):
chat_mode = ChatScene.ExcelLearning
""" """
self.excel_file_path = select_param
@ -31,20 +38,19 @@ class ExcelLearning(BaseChat):
super().__init__(
chat_mode=chat_mode,
chat_session_id=chat_session_id,
current_user_input = user_input,
current_user_input=user_input,
select_param=select_param,
)
if parent_mode:
self.current_message.chat_mode = parent_mode.value()
def generate_input_values(self):
    """Build the prompt input values for the Excel learning scene.

    Returns:
        A dict with the single key ``data_example`` holding the JSON
        serialisation (via ``DateTimeEncoder``) of
        ``self.excel_reader.get_sample_data()``.
    """
    # Query the sample only once. The earlier code ran the query twice and
    # built a header-prefixed row list from the first result that was never
    # used; the serialised payload is unchanged.
    sample = self.excel_reader.get_sample_data()
    return {"data_example": json.dumps(sample, cls=DateTimeEncoder)}

View File

@ -37,23 +37,23 @@ class LearningExcelOutputParser(BaseOutputParser):
plans = response[key]
return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
def parse_view_response(self, speak, data) -> str:
    """Render a learning result as a three-section Markdown summary.

    Args:
        speak: Not used by this implementation.
        data: Object whose ``desciption`` (text), ``clounms`` (iterable of
            mappings from column name to explanation) and ``plans`` (iterable
            of plan items) attributes are read.

    Returns:
        Markdown text made of a data-introduction section, a data-structure
        section with one bullet per column entry, and an analysis-plan
        section with one line per plan, joined by newlines.
    """
    title_section = f"### **数据简介**\n{data.desciption} "

    # Every key of one mapping shares that mapping's 1-based position.
    column_lines = []
    for position, column in enumerate(data.clounms, start=1):
        column_lines.extend(
            f"- **{position}.[{name}]** _{column[name]}_\n" for name in column.keys()
        )
    columns_section = "### **数据结构**\n" + "".join(column_lines)

    plans_section = "### **分析计划**\n" + "".join(
        f"{plan} \n" for plan in data.plans
    )
    return f"{title_section}\n{columns_section}\n{plans_section}"

View File

@ -2,7 +2,9 @@ import json
from pilot.prompts.prompt_new import PromptTemplate
from pilot.configs.config import Config
from pilot.scene.base import ChatScene
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import LearningExcelOutputParser
from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import (
LearningExcelOutputParser,
)
from pilot.common.schema import SeparatorStyle
CFG = Config()
@ -29,7 +31,7 @@ _DEFAULT_TEMPLATE_ZH = """
{response}
"""
RESPONSE_FORMAT_SIMPLE = {
RESPONSE_FORMAT_SIMPLE = {
"DataAnalysis": "数据内容分析总结",
"ColumnAnalysis": [{"column name1": "字段1介绍专业术语解释(请尽量简单明了)"}],
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
@ -63,5 +65,3 @@ prompt = PromptTemplate(
temperature=PROMPT_TEMPERATURE,
)
CFG.prompt_template_registry.register(prompt, is_default=True)

View File

@ -7,7 +7,7 @@ import seaborn as sns
import matplotlib.pyplot as plt
import time
from fsspec import filesystem
import spatial
import spatial
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
@ -18,22 +18,35 @@ if __name__ == "__main__":
excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx")
# colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear")
colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """)
df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;")
colunms, datas = excel_reader.run(
""" SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """
)
df = excel_reader.get_df_by_sql_ex(
"SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;"
)
columns = df.columns.tolist()
plt.rcParams["font.family"] = ["sans-serif"]
rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False}
sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"})
sns.set_style(rc={"font.sans-serif": "Microsoft Yahei"})
sns.set(context="notebook", style="ticks", color_codes=True, rc=rc)
sns.set_palette("Set3") # 设置颜色主题
# fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
plt.subplots_adjust(top=0.9)
ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%')
ax = df.plot(
kind="pie",
y=columns[1],
ax=ax,
labels=df[columns[0]].values,
startangle=90,
autopct="%1.1f%%",
)
# 手动设置 labels 的位置和大小
ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10)
plt.axis('equal') # 使饼图为正圆形
ax.legend(
loc="center left", bbox_to_anchor=(-1, 0.5, 0, 0), labels=None, fontsize=10
)
plt.axis("equal") # 使饼图为正圆形
plt.show()
#
#

View File

@ -7,11 +7,13 @@ import numpy as np
from pilot.common.pd_utils import csv_colunm_foramt
def excel_colunm_format(old_name: str) -> str:
    """Normalise a column header.

    Leading and trailing whitespace is stripped first, then every remaining
    space is replaced with an underscore.
    """
    return old_name.strip().replace(" ", "_")
def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "")
parsed = sqlparse.parse(sql)
@ -20,44 +22,51 @@ def add_quotes(sql, column_names=[]):
deep_quotes(token, column_names)
return str(parsed[0])
def deep_quotes(token, column_names=None):
    """Recursively wrap SQL name tokens in double quotes, in place.

    Args:
        token: A sqlparse token. A token that has a ``tokens`` attribute is
            walked recursively; any other token is quoted when its ``ttype``
            equals ``sqlparse.tokens.Name``.
        column_names: Optional collection of names. When it is empty or
            omitted, every name token is quoted; otherwise only tokens whose
            value is a member of the collection are quoted.

    Backticks inside a quoted value are removed before the double quotes are
    added.
    """
    # ``None`` replaces the former mutable ``[]`` default; both mean
    # "no filter", so existing callers behave the same.
    if hasattr(token, "tokens"):
        for child in token.tokens:
            deep_quotes(child, column_names)
        return

    if token.ttype == sqlparse.tokens.Name:
        if not column_names or token.value in column_names:
            token.value = f'"{token.value.replace("`", "")}"'
def is_chinese(string):
    """Return ``True`` when *string* contains at least one Chinese character.

    A character counts as Chinese when it falls in the range U+4E00..U+9FA5.
    """
    # Use a regular expression to look for the first Chinese character.
    return re.search(r"[一-龥]", string) is not None
class ExcelReader:
def __init__(self, file_path):
file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0]
self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1]
# read excel file
if file_path.endswith('.xlsx') or file_path.endswith('.xls'):
if file_path.endswith(".xlsx") or file_path.endswith(".xls"):
df_tmp = pd.read_excel(file_path)
self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])})
elif file_path.endswith('.csv'):
self.df = pd.read_excel(
file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
elif file_path.endswith(".csv"):
df_tmp = pd.read_csv(file_path)
self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])})
self.df = pd.read_csv(
file_path,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else:
raise ValueError("Unsupported file format.")
self.df.replace('', np.nan, inplace=True)
self.df.replace("", np.nan, inplace=True)
self.columns_map = {}
for column_name in df_tmp.columns:
self.columns_map.update({column_name: excel_colunm_format(column_name)})
@ -66,18 +75,17 @@ class ExcelReader:
except Exception as e:
print("transfor column error" + column_name)
self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_'))
self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_"))
# connect DuckDB
self.db = duckdb.connect(database=':memory:', read_only=False)
self.db = duckdb.connect(database=":memory:", read_only=False)
self.table_name = file_name_without_extension
# write data in duckdb
self.db.register(self.table_name, self.df)
def run(self, sql):
if f'"{self.table_name}"' not in sql:
if f'"{self.table_name}"' not in sql:
sql = sql.replace(self.table_name, f'"{self.table_name}"')
sql = add_quotes(sql, self.columns_map.values())
print(f"excute sql:{sql}")
@ -92,5 +100,4 @@ class ExcelReader:
return pd.DataFrame(values, columns=colunms)
def get_sample_data(self):
    """Return the result of ``run`` for a five-row ``SELECT *`` on this reader's table."""
    sample_sql = f"SELECT * FROM {self.table_name} LIMIT 5;"
    return self.run(sample_sql)

View File

@ -21,7 +21,7 @@ class ChatWithDbAutoExecute(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = ""):
def __init__(self, chat_session_id, user_input, select_param: str = ""):
chat_mode = ChatScene.ChatWithDbExecute
self.db_name = select_param
""" """
@ -47,7 +47,9 @@ class ChatWithDbAutoExecute(BaseChat):
client = DBSummaryClient()
try:
table_infos = client.get_db_summary(
dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE
dbname=self.db_name,
query=self.current_user_input,
topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE,
)
except Exception as e:
print("db summary find error!" + str(e))
@ -65,4 +67,4 @@ class ChatWithDbAutoExecute(BaseChat):
def do_action(self, prompt_response):
    """Execute the SQL carried by *prompt_response* on ``self.database``.

    The response is printed first for tracing; the return value is whatever
    ``self.database.run`` returns for ``prompt_response.sql``.
    """
    print(f"do_action:{prompt_response}")
    sql_text = prompt_response.sql
    return self.database.run(sql_text)

View File

@ -8,6 +8,7 @@ from pilot.out_parser.base import BaseOutputParser, T
from pilot.configs.model_config import LOGDIR
from pilot.configs.config import Config
from pilot.scene.chat_db.data_loader import DbDataLoader
CFG = Config()
@ -49,6 +50,4 @@ class DbChatOutputParser(BaseOutputParser):
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text
else:
return data_loader.get_table_view_by_conn(data, speak)
return data_loader.get_table_view_by_conn(data, speak)

View File

@ -1,8 +1,7 @@
import pandas as pd
class DbDataLoader:
def get_table_view_by_conn(self, data, speak):
### tool out data to table view
if len(data) <= 1:
@ -12,4 +11,4 @@ class DbDataLoader:
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
return view_text
return view_text

View File

@ -19,7 +19,7 @@ class ChatWithDbQA(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param:str = ""):
def __init__(self, chat_session_id, user_input, select_param: str = ""):
""" """
self.db_name = select_param
super().__init__(

View File

@ -20,12 +20,7 @@ class ChatWithPlugin(BaseChat):
plugins_prompt_generator: PluginPromptGenerator
select_plugin: str = None
def __init__(
self,
chat_session_id,
user_input,
select_param: str = None
):
def __init__(self, chat_session_id, user_input, select_param: str = None):
self.plugin_selector = select_param
super().__init__(
chat_mode=ChatScene.ChatExecution,

View File

@ -30,7 +30,7 @@ class ChatKnowledge(BaseChat):
"""Number of results to return from the query"""
def __init__(self, chat_session_id, user_input, select_param: str = None):
def __init__(self, chat_session_id, user_input, select_param: str = None):
""" """
self.knowledge_space = select_param
super().__init__(

View File

@ -30,7 +30,6 @@ class ChatNormal(BaseChat):
input_values = {"input": self.current_user_input}
return input_values
@property
def chat_type(self) -> str:
    """Return the value of ``ChatScene.ChatNormal``."""
    # ``value`` is invoked as a method on ChatScene members elsewhere in this
    # code base (``ChatScene.ChatExcel.value()``, ``chat_mode.value()``);
    # calling it here makes the property yield the declared ``str`` rather
    # than the uncalled attribute.
    # NOTE(review): confirm ChatScene.value is a method in pilot.scene.base.
    return ChatScene.ChatNormal.value()

View File

@ -117,12 +117,10 @@ def _conversation_to_dic(once: OnceConversation) -> dict:
"tokens": once.tokens if once.tokens else 0,
"messages": messages_to_dict(once.messages),
"param_type": once.param_type,
"param_value": once.param_value
"param_value": once.param_value,
}
def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
    """Convert each conversation with ``_conversation_to_dic``, preserving order."""
    return list(map(_conversation_to_dic, conversations))

View File

@ -64,16 +64,12 @@ def server_init(args):
cfg.command_registry = command_registry
command_disply_commands = [
"pilot.commands.disply_type.show_chart_gen",
"pilot.commands.disply_type.show_table_gen",
"pilot.commands.disply_type.show_text_gen",
]
command_disply_registry = CommandRegistry()
command_disply_registry = CommandRegistry()
for command in command_disply_commands:
command_disply_registry.import_commands(command)
cfg.command_disply = command_disply_registry

View File

@ -81,7 +81,9 @@ app.include_router(knowledge_router)
# app.include_router(api_editor_route_v1)
os.makedirs(static_message_img_path, exist_ok=True)
app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2")
app.mount(
"/images", StaticFiles(directory=static_message_img_path, html=True), name="static2"
)
app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static"))
app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static")

View File

@ -19,7 +19,6 @@ from pilot.utils import build_logger
logger = build_logger("db_summary", LOGDIR + "db_summary.log")
CFG = Config()
chat_factory = ChatFactory()
@ -145,7 +144,9 @@ class DBSummaryClient:
try:
self.db_summary_embedding(item["db_name"], item["db_type"])
except Exception as e:
logger.warn(f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e)
logger.warn(
f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e
)
def init_db_profile(self, db_summary_client, dbname, embeddings):
profile_store_config = {

View File

@ -75,7 +75,6 @@ def build_logger(logger_name, logger_filename):
item.addHandler(handler)
logging.basicConfig(level=logging.INFO)
# Get logger
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)