From 71b9cd14a6be8dfac52b29a9242566d1c9893808 Mon Sep 17 00:00:00 2001 From: aries_ckt <916701291@qq.com> Date: Tue, 29 Aug 2023 19:57:33 +0800 Subject: [PATCH] style:fmt --- pilot/commands/disply_type/show_chart_gen.py | 112 ++++++--- pilot/commands/disply_type/show_table_gen.py | 8 +- pilot/commands/disply_type/show_text_gen.py | 9 +- pilot/common/pd_utils.py | 4 +- pilot/connections/rdbms/base.py | 15 +- pilot/connections/rdbms/tests/mange_t.py | 4 +- pilot/memory/chat_history/duckdb_history.py | 4 +- pilot/openapi/api_v1/api_v1.py | 14 +- pilot/openapi/api_v1/editor/api_editor_v1.py | 212 +++++++++++------- pilot/openapi/api_v1/editor/sql_editor.py | 2 +- pilot/openapi/editor_view_model.py | 3 +- pilot/out_parser/base.py | 12 +- pilot/scene/base.py | 46 ++-- pilot/scene/base_chat.py | 24 +- pilot/scene/base_message.py | 2 + pilot/scene/chat_dashboard/chat.py | 16 +- pilot/scene/chat_dashboard/data_loader.py | 13 +- .../chat_excel/excel_analyze/chat.py | 26 +-- .../chat_excel/excel_analyze/prompt.py | 13 +- .../chat_excel/excel_learning/chat.py | 18 +- .../chat_excel/excel_learning/out_parser.py | 10 +- .../chat_excel/excel_learning/prompt.py | 8 +- .../chat_excel/excel_learning/test.py | 27 ++- .../chat_data/chat_excel/excel_reader.py | 41 ++-- pilot/scene/chat_db/auto_execute/chat.py | 8 +- .../scene/chat_db/auto_execute/out_parser.py | 5 +- pilot/scene/chat_db/data_loader.py | 5 +- pilot/scene/chat_db/professional_qa/chat.py | 2 +- pilot/scene/chat_execution/chat.py | 7 +- pilot/scene/chat_knowledge/v1/chat.py | 2 +- pilot/scene/chat_normal/chat.py | 1 - pilot/scene/message.py | 4 +- pilot/server/base.py | 6 +- pilot/server/dbgpt_server.py | 4 +- pilot/summary/db_summary_client.py | 5 +- pilot/utils.py | 1 - 36 files changed, 413 insertions(+), 280 deletions(-) diff --git a/pilot/commands/disply_type/show_chart_gen.py b/pilot/commands/disply_type/show_chart_gen.py index 1d3514ced..ec3559e27 100644 --- a/pilot/commands/disply_type/show_chart_gen.py +++ b/pilot/commands/disply_type/show_chart_gen.py @@ -20,8 +20,17 @@ CFG = Config() logger = build_logger("show_chart_gen", LOGDIR + "show_chart_gen.log") static_message_img_path = os.path.join(os.getcwd(), "message/img") + def zh_font_set(): - font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] fm = FontManager() mat_fonts = set(f.name for f in fm.ttflist) can_use_fonts = [] @@ -29,11 +38,14 @@ def zh_font_set(): if font_name in mat_fonts: can_use_fonts.append(font_name) if len(can_use_fonts) > 0: - plt.rcParams['font.sans-serif'] = can_use_fonts + plt.rcParams["font.sans-serif"] = can_use_fonts -@command("response_line_chart", "Line chart display, used to display comparative trend analysis data", - '"speak": "", "df":""') +@command( + "response_line_chart", + "Line chart display, used to display comparative trend analysis data", + '"speak": "", "df":""', +) def response_line_chart(speak: str, df: DataFrame) -> str: logger.info(f"response_line_chart:{speak},") @@ -44,7 +56,15 @@ def response_line_chart(speak: str, df: DataFrame) -> str: # set font # zh_font_set() - font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] fm = FontManager() mat_fonts = set(f.name for f in fm.ttflist) can_use_fonts = [] @@ -52,31 +72,34 @@ def response_line_chart(speak: str, df: DataFrame) -> str: if font_name in mat_fonts: can_use_fonts.append(font_name) if len(can_use_fonts) > 0: - plt.rcParams['font.sans-serif'] = can_use_fonts + plt.rcParams["font.sans-serif"] = can_use_fonts - rc = {'font.sans-serif': can_use_fonts} - plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 + rc = {"font.sans-serif": can_use_fonts} + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set_palette("Set3") # 设置颜色主题 sns.set_style("dark") sns.color_palette("hls", 10) - sns.hls_palette(8, l=.5, s=.7) - sns.set(context='notebook', style='ticks', rc=rc) + sns.hls_palette(8, l=0.5, s=0.7) + sns.set(context="notebook", style="ticks", rc=rc) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) sns.lineplot(df, x=columns[0], y=columns[1], ax=ax) chart_name = "line_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name - plt.savefig(chart_path, bbox_inches='tight', dpi=100) + plt.savefig(chart_path, bbox_inches="tight", dpi=100) html_img = f"""
{speak}
""" return html_img -@command("response_bar_chart", "Histogram, suitable for comparative analysis of multiple target values", - '"speak": "", "df":""') +@command( + "response_bar_chart", + "Histogram, suitable for comparative analysis of multiple target values", + '"speak": "", "df":""', +) def response_bar_chart(speak: str, df: DataFrame) -> str: logger.info(f"response_bar_chart:{speak},") columns = df.columns.tolist() @@ -85,7 +108,15 @@ def response_bar_chart(speak: str, df: DataFrame) -> str: # set font # zh_font_set() - font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] fm = FontManager() mat_fonts = set(f.name for f in fm.ttflist) can_use_fonts = [] @@ -93,29 +124,32 @@ def response_bar_chart(speak: str, df: DataFrame) -> str: if font_name in mat_fonts: can_use_fonts.append(font_name) if len(can_use_fonts) > 0: - plt.rcParams['font.sans-serif'] = can_use_fonts + plt.rcParams["font.sans-serif"] = can_use_fonts - rc = {'font.sans-serif': can_use_fonts} - plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 + rc = {"font.sans-serif": can_use_fonts} + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题 sns.set_palette("Set3") # 设置颜色主题 sns.set_style("dark") sns.color_palette("hls", 10) - sns.hls_palette(8, l=.5, s=.7) - sns.set(context='notebook', style='ticks', rc=rc) + sns.hls_palette(8, l=0.5, s=0.7) + sns.set(context="notebook", style="ticks", rc=rc) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) sns.barplot(df, x=df[columns[0]], y=df[columns[1]], ax=ax) chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name - plt.savefig(chart_path, bbox_inches='tight', dpi=100) + plt.savefig(chart_path, bbox_inches="tight", dpi=100) html_img = f"""
{speak}
""" return html_img -@command("response_pie_chart", "Pie chart, suitable for scenarios such as proportion and distribution statistics", - '"speak": "", "df":""') +@command( + "response_pie_chart", + "Pie chart, suitable for scenarios such as proportion and distribution statistics", + '"speak": "", "df":""', +) def response_pie_chart(speak: str, df: DataFrame) -> str: logger.info(f"response_pie_chart:{speak},") columns = df.columns.tolist() @@ -123,7 +157,15 @@ def response_pie_chart(speak: str, df: DataFrame) -> str: raise ValueError("No Data!") # set font # zh_font_set() - font_names = ['Heiti TC', 'Songti SC', 'STHeiti Light', 'Microsoft YaHei', 'SimSun', 'SimHei', 'KaiTi'] + font_names = [ + "Heiti TC", + "Songti SC", + "STHeiti Light", + "Microsoft YaHei", + "SimSun", + "SimHei", + "KaiTi", + ] fm = FontManager() mat_fonts = set(f.name for f in fm.ttflist) can_use_fonts = [] @@ -131,23 +173,35 @@ def response_pie_chart(speak: str, df: DataFrame) -> str: if font_name in mat_fonts: can_use_fonts.append(font_name) if len(can_use_fonts) > 0: - plt.rcParams['font.sans-serif'] = can_use_fonts - plt.rcParams['axes.unicode_minus'] = False # 解决无法显示符号的问题 + plt.rcParams["font.sans-serif"] = can_use_fonts + plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题 sns.set_palette("Set3") # 设置颜色主题 # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) - ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') + ax = df.plot( + kind="pie", + y=columns[1], + ax=ax, + labels=df[columns[0]].values, + startangle=90, + autopct="%1.1f%%", + ) # 手动设置 labels 的位置和大小 - ax.legend(loc='upper right', bbox_to_anchor=(0, 0, 1, 1), labels=df[columns[0]].values, fontsize=10) + ax.legend( + loc="upper right", + bbox_to_anchor=(0, 0, 1, 1), + labels=df[columns[0]].values, + fontsize=10, + ) - plt.axis('equal') # 使饼图为正圆形 + plt.axis("equal") # 使饼图为正圆形 # plt.title(columns[0]) chart_name = "pie_" + str(uuid.uuid1()) + ".png" chart_path = static_message_img_path + "/" + chart_name - plt.savefig(chart_path, bbox_inches='tight', dpi=100) + plt.savefig(chart_path, bbox_inches="tight", dpi=100) html_img = f"""
{speak.replace("`", '"')}
""" diff --git a/pilot/commands/disply_type/show_table_gen.py b/pilot/commands/disply_type/show_table_gen.py index 0e3cf8dad..d67e39eb9 100644 --- a/pilot/commands/disply_type/show_table_gen.py +++ b/pilot/commands/disply_type/show_table_gen.py @@ -11,8 +11,12 @@ CFG = Config() logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") -@command("response_table", "Table display, suitable for display with many display columns or non-numeric columns", '"speak": "", "df":""') -def response_table(speak: str, df: DataFrame) -> str: +@command( + "response_table", + "Table display, suitable for display with many display columns or non-numeric columns", + '"speak": "", "df":""', +) +def response_table(speak: str, df: DataFrame) -> str: logger.info(f"response_table:{speak}") html_table = df.to_html(index=False, escape=False, sparsify=False) table_str = "".join(html_table.split()) diff --git a/pilot/commands/disply_type/show_text_gen.py b/pilot/commands/disply_type/show_text_gen.py index 16932ff1c..16d9fbe91 100644 --- a/pilot/commands/disply_type/show_text_gen.py +++ b/pilot/commands/disply_type/show_text_gen.py @@ -10,9 +10,12 @@ CFG = Config() logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log") -@command("response_data_text", "Text display, the default display method, suitable for single-line or simple content display", - '"speak": "", "df":""') -def response_data_text(speak: str, df: DataFrame) -> str: +@command( + "response_data_text", + "Text display, the default display method, suitable for single-line or simple content display", + '"speak": "", "df":""', +) +def response_data_text(speak: str, df: DataFrame) -> str: logger.info(f"response_data_text:{speak}") data = df.values diff --git a/pilot/common/pd_utils.py b/pilot/common/pd_utils.py index 96729e124..9c46055ac 100644 --- a/pilot/common/pd_utils.py +++ b/pilot/common/pd_utils.py @@ -1,6 +1,6 @@ def csv_colunm_foramt(val): if str(val).find("$") >= 0: - return float(val.replace('$', '').replace(',', '')) + return float(val.replace("$", "").replace(",", "")) if str(val).find("¥") >= 0: - return float(val.replace('¥', '').replace(',', '')) + return float(val.replace("¥", "").replace(",", "")) return val diff --git a/pilot/connections/rdbms/base.py b/pilot/connections/rdbms/base.py index 51d70d386..d243b34eb 100644 --- a/pilot/connections/rdbms/base.py +++ b/pilot/connections/rdbms/base.py @@ -262,7 +262,7 @@ class RDBMSDatabase(BaseConnect): """Format the error message""" return f"Error: {e}" - def __write(self, write_sql): + def __write(self, write_sql): print(f"Write[{write_sql}]") db_cache = self._engine.url.database result = self.session.execute(text(write_sql)) @@ -272,7 +272,7 @@ class RDBMSDatabase(BaseConnect): print(f"SQL[{write_sql}], result:{result.rowcount}") return result.rowcount - def __query(self,query, fetch: str = "all"): + def __query(self, query, fetch: str = "all"): """ only for query Args: @@ -325,6 +325,7 @@ class RDBMSDatabase(BaseConnect): result = list(result) return field_names, result return [] + def run(self, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results.""" print("SQL:" + command) @@ -333,12 +334,12 @@ class RDBMSDatabase(BaseConnect): parsed, ttype, sql_type, table_name = self.__sql_parse(command) if ttype == sqlparse.tokens.DML: if sql_type == "SELECT": - return self.__query( command, fetch) + return self.__query(command, fetch) else: - self.__write( command) + self.__write(command) select_sql = self.convert_sql_write_to_select(command) print(f"write result query:{select_sql}") - return self.__query( select_sql) + return self.__query(select_sql) else: print(f"DDL execution determines whether to enable through configuration ") @@ -351,10 +352,10 @@ class RDBMSDatabase(BaseConnect): result.insert(0, field_names) print("DDL Result:" + str(result)) if not result: - return self.__query( f"SHOW COLUMNS FROM {table_name}") + return self.__query(f"SHOW COLUMNS FROM {table_name}") return result else: - return self.__query( f"SHOW COLUMNS FROM {table_name}") + return self.__query(f"SHOW COLUMNS FROM {table_name}") def run_no_throw(self, session, command: str, fetch: str = "all") -> List: """Execute a SQL command and return a string representing the results. diff --git a/pilot/connections/rdbms/tests/mange_t.py b/pilot/connections/rdbms/tests/mange_t.py index 820e321de..5140341d6 100644 --- a/pilot/connections/rdbms/tests/mange_t.py +++ b/pilot/connections/rdbms/tests/mange_t.py @@ -2,6 +2,6 @@ from pilot.configs.config import Config from pilot.connections.manages.connection_manager import ConnectManager if __name__ == "__main__": - mange= ConnectManager() + mange = ConnectManager() types = mange.get_all_completed_types() - print(str(types)) \ No newline at end of file + print(str(types)) diff --git a/pilot/memory/chat_history/duckdb_history.py b/pilot/memory/chat_history/duckdb_history.py index 827107515..09f4612df 100644 --- a/pilot/memory/chat_history/duckdb_history.py +++ b/pilot/memory/chat_history/duckdb_history.py @@ -94,8 +94,7 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): cursor.commit() self.connect.commit() - - def update(self, messages:List[OnceConversation]) -> None: + def update(self, messages: List[OnceConversation]) -> None: cursor = self.connect.cursor() cursor.execute( "UPDATE chat_history set messages=? where conv_uid=?", @@ -161,7 +160,6 @@ class DuckdbHistoryMemory(BaseChatHistoryMemory): return {} - def get_messages(self) -> List[OnceConversation]: cursor = self.connect.cursor() cursor.execute( diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index 33e226468..7b7f0d4dd 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -110,7 +110,6 @@ async def db_connect_delete(db_name: str = None): @router.get("/v1/chat/db/support/type", response_model=Result[DbTypeInfo]) async def db_support_types(): - support_types = CFG.LOCAL_DB_MANAGE.get_all_completed_types() db_type_infos = [] for type in support_types: @@ -130,7 +129,7 @@ async def dialogue_list(user_id: str = None): chat_mode = item.get("chat_mode") messages = json.loads(item.get("messages")) - last_round = max(messages, key=lambda x: x['chat_order']) + last_round = max(messages, key=lambda x: x["chat_order"]) if "param_value" in last_round: select_param = last_round["param_value"] else: @@ -139,7 +138,7 @@ async def dialogue_list(user_id: str = None): conv_uid=conv_uid, user_input=summary, chat_mode=chat_mode, - select_param=select_param + select_param=select_param, ) dialogues.append(conv_vo) @@ -213,7 +212,9 @@ async def params_load(conv_uid: str, chat_mode: str, doc_file: UploadFile = File ), ) ## chat prepare - dialogue = ConversationVo(conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename) + dialogue = ConversationVo( + conv_uid=conv_uid, chat_mode=chat_mode, select_param=doc_file.filename + ) chat: BaseChat = get_chat_instance(dialogue) resp = chat.prepare() @@ -229,7 +230,8 @@ async def dialogue_delete(con_uid: str): history_mem.delete() return Result.succ(None) -def get_hist_messages(conv_uid:str): + +def get_hist_messages(conv_uid: str): message_vos: List[MessageVo] = [] history_mem = DuckdbHistoryMemory(conv_uid) history_messages: List[OnceConversation] = history_mem.get_messages() @@ -264,7 +266,7 @@ def get_chat_instance(dialogue: ConversationVo = Body()) -> BaseChat: chat_param = { "chat_session_id": dialogue.conv_uid, "user_input": dialogue.user_input, - "select_param": dialogue.select_param + "select_param": dialogue.select_param, } chat: BaseChat = CHAT_FACTORY.get_implementation(dialogue.chat_mode, **chat_param) return chat diff --git a/pilot/openapi/api_v1/editor/api_editor_v1.py b/pilot/openapi/api_v1/editor/api_editor_v1.py index fa9b6b238..896c91877 100644 --- a/pilot/openapi/api_v1/editor/api_editor_v1.py +++ b/pilot/openapi/api_v1/editor/api_editor_v1.py @@ -22,7 +22,7 @@ from pilot.openapi.editor_view_model import ( ChartDetail, ChatChartEditContext, ChatSqlEditContext, - DbTable + DbTable, ) from pilot.openapi.api_v1.editor.sql_editor import DataNode, ChartRunData, SqlRunData @@ -38,7 +38,9 @@ logger = build_logger("api_editor_v1", LOGDIR + "api_editor_v1.log") @router.get("/v1/editor/db/tables", response_model=Result[DbTable]) -async def get_editor_tables(db_name: str, page_index: int, page_size: int, search_str: str = ""): +async def get_editor_tables( + db_name: str, page_index: int, page_size: int, search_str: str = "" +): logger.info(f"get_editor_tables:{db_name},{page_index},{page_size},{search_str}") db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) tables = db_conn.get_table_names() @@ -49,8 +51,15 @@ async def get_editor_tables(db_name: str, page_index: int, page_size: int, searc fields = db_conn.get_fields(table) for field in fields: table_node.children.append( - DataNode(title=field[0], key=field[0], type=field[1], default_value=field[2], can_null=field[3], - comment=field[-1])) + DataNode( + title=field[0], + key=field[0], + type=field[1], + default_value=field[2], + can_null=field[3], + comment=field[-1], + ) + ) return Result.succ(db_node) @@ -68,8 +77,11 @@ async def get_editor_sql_rounds(con_uid: str): if element["type"] == "human": round_name = element["data"]["content"] if once.get("param_value"): - round: ChatDbRounds = ChatDbRounds(round=once["chat_order"], db_name=once["param_value"], - round_name=round_name) + round: ChatDbRounds = ChatDbRounds( + round=once["chat_order"], + db_name=once["param_value"], + round_name=round_name, + ) result.append(round) return Result.succ(result) @@ -84,8 +96,14 @@ async def get_editor_sql(con_uid: str, round: int): if int(once["chat_order"]) == round: for element in once["messages"]: if element["type"] == "ai": - logger.info(f'history ai json resp:{element["data"]["content"]}') - context = element["data"]["content"].replace("\\n", " ").replace("\n", " ") + logger.info( + f'history ai json resp:{element["data"]["content"]}' + ) + context = ( + element["data"]["content"] + .replace("\\n", " ") + .replace("\n", " ") + ) return Result.succ(json.loads(context)) return Result.faild(msg="not have sql!") @@ -93,8 +111,8 @@ async def get_editor_sql(con_uid: str, round: int): @router.post("/v1/editor/sql/run", response_model=Result[SqlRunData]) async def editor_sql_run(run_param: dict = Body()): logger.info(f"editor_sql_run:{run_param}") - db_name = run_param['db_name'] - sql = run_param['sql'] + db_name = run_param["db_name"] + sql = run_param["sql"] if not db_name and not sql: return Result.faild("SQL run param error!") conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) @@ -104,18 +122,17 @@ async def editor_sql_run(run_param: dict = Body()): colunms, sql_result = conn.query_ex(sql) # 计算执行耗时 end_time = time.time() * 1000 - sql_run_data: SqlRunData = SqlRunData(result_info="", - run_cost=(end_time - start_time) / 1000, - colunms=colunms, - values=sql_result - ) + sql_run_data: SqlRunData = SqlRunData( + result_info="", + run_cost=(end_time - start_time) / 1000, + colunms=colunms, + values=sql_result, + ) return Result.succ(sql_run_data) except Exception as e: - return Result.succ(SqlRunData(result_info=str(e), - run_cost=0, - colunms=[], - values=[] - )) + return Result.succ( + SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[]) + ) @router.post("/v1/sql/editor/submit") @@ -126,18 +143,24 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): if history_messages: conn = CFG.LOCAL_DB_MANAGE.get_connect(sql_edit_context.db_name) - edit_round = list(filter(lambda x: x['chat_order'] == sql_edit_context.conv_round, history_messages))[0] + edit_round = list( + filter( + lambda x: x["chat_order"] == sql_edit_context.conv_round, + history_messages, + ) + )[0] if edit_round: for element in edit_round["messages"]: if element["type"] == "ai": db_resp = json.loads(element["data"]["content"]) - db_resp['thoughts'] = sql_edit_context.new_speak - db_resp['sql'] = sql_edit_context.new_sql + db_resp["thoughts"] = sql_edit_context.new_speak + db_resp["sql"] = sql_edit_context.new_sql element["data"]["content"] = json.dumps(db_resp) if element["type"] == "view": data_loader = DbDataLoader() - element["data"]["content"] = data_loader.get_table_view_by_conn(conn.run(sql_edit_context.new_sql), - sql_edit_context.new_speak) + element["data"]["content"] = data_loader.get_table_view_by_conn( + conn.run(sql_edit_context.new_sql), sql_edit_context.new_speak + ) history_mem.update(history_messages) return Result.succ(None) return Result.faild(msg="Edit Faild!") @@ -145,16 +168,21 @@ async def sql_editor_submit(sql_edit_context: ChatSqlEditContext = Body()): @router.get("/v1/editor/chart/list", response_model=Result[ChartList]) async def get_editor_chart_list(con_uid: str): - logger.info(f"get_editor_sql_rounds:{con_uid}", ) + logger.info( + f"get_editor_sql_rounds:{con_uid}", + ) history_mem = DuckdbHistoryMemory(con_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: - last_round = max(history_messages, key=lambda x: x['chat_order']) + last_round = max(history_messages, key=lambda x: x["chat_order"]) db_name = last_round["param_value"] for element in last_round["messages"]: if element["type"] == "ai": - chart_list: ChartList = ChartList(round=last_round['chat_order'], db_name=db_name, - charts=json.loads(element["data"]["content"])) + chart_list: ChartList = ChartList( + round=last_round["chat_order"], + db_name=db_name, + charts=json.loads(element["data"]["content"]), + ) return Result.succ(chart_list) return Result.faild(msg="Not have charts!") @@ -162,33 +190,40 @@ async def get_editor_chart_list(con_uid: str): @router.post("/v1/editor/chart/info", response_model=Result[ChartDetail]) async def get_editor_chart_info(param: dict = Body()): logger.info(f"get_editor_chart_info:{param}") - conv_uid = param['con_uid'] - chart_title = param['chart_title'] + conv_uid = param["con_uid"] + chart_title = param["chart_title"] history_mem = DuckdbHistoryMemory(conv_uid) history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: - last_round = max(history_messages, key=lambda x: x['chat_order']) + last_round = max(history_messages, key=lambda x: x["chat_order"]) db_name = last_round["param_value"] if not db_name: - logger.error("this dashboard dialogue version too old, can't support editor!") - return Result.faild(msg="this dashboard dialogue version too old, can't support editor!") + logger.error( + "this dashboard dialogue version too old, can't support editor!" + ) + return Result.faild( + msg="this dashboard dialogue version too old, can't support editor!" + ) for element in last_round["messages"]: if element["type"] == "view": - view_data: dict = json.loads(element["data"]["content"]); + view_data: dict = json.loads(element["data"]["content"]) charts: List = view_data.get("charts") - find_chart = list(filter(lambda x: x['chart_name'] == chart_title, charts))[0] + find_chart = list( + filter(lambda x: x["chart_name"] == chart_title, charts) + )[0] conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) - detail: ChartDetail = ChartDetail(chart_uid=find_chart['chart_uid'], - chart_type=find_chart['chart_type'], - chart_desc=find_chart['chart_desc'], - chart_sql=find_chart['chart_sql'], - db_name=db_name, - chart_name=find_chart['chart_name'], - chart_value=find_chart['values'], - table_value=conn.run(find_chart['chart_sql']) - ) + detail: ChartDetail = ChartDetail( + chart_uid=find_chart["chart_uid"], + chart_type=find_chart["chart_type"], + chart_desc=find_chart["chart_desc"], + chart_sql=find_chart["chart_sql"], + db_name=db_name, + chart_name=find_chart["chart_name"], + chart_value=find_chart["values"], + table_value=conn.run(find_chart["chart_sql"]), + ) return Result.succ(detail) return Result.faild(msg="Can't Find Chart Detail Info!") @@ -197,36 +232,39 @@ async def get_editor_chart_info(param: dict = Body()): @router.post("/v1/editor/chart/run", response_model=Result[ChartRunData]) async def editor_chart_run(run_param: dict = Body()): logger.info(f"editor_chart_run:{run_param}") - db_name = run_param['db_name'] - sql = run_param['sql'] - chart_type = run_param['chart_type'] + db_name = run_param["db_name"] + sql = run_param["sql"] + chart_type = run_param["chart_type"] if not db_name and not sql: return Result.faild("SQL run param error!") try: dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) colunms, sql_result = db_conn.query_ex(sql) - field_names, chart_values = dashboard_data_loader.get_chart_values_by_data(colunms, sql_result, sql) + field_names, chart_values = dashboard_data_loader.get_chart_values_by_data( + colunms, sql_result, sql + ) start_time = time.time() * 1000 # 计算执行耗时 end_time = time.time() * 1000 - sql_run_data: SqlRunData = SqlRunData(result_info="", - run_cost=(end_time - start_time) / 1000, - colunms=colunms, - values=sql_result - ) - return Result.succ(ChartRunData(sql_data=sql_run_data, chart_values=chart_values, chart_type = chart_type)) + sql_run_data: SqlRunData = SqlRunData( + result_info="", + run_cost=(end_time - start_time) / 1000, + colunms=colunms, + values=sql_result, + ) + return Result.succ( + ChartRunData( + sql_data=sql_run_data, chart_values=chart_values, chart_type=chart_type + ) + ) except Exception as e: - sql_result = SqlRunData(result_info=str(e), - run_cost=0, - colunms=[], - values=[] - ) - return Result.succ(ChartRunData(sql_data = sql_result, - chart_values=[], - chart_type = chart_type - )) + sql_result = SqlRunData(result_info=str(e), run_cost=0, colunms=[], values=[]) + return Result.succ( + ChartRunData(sql_data=sql_result, chart_values=[], chart_type=chart_type) + ) + @router.post("/v1/chart/editor/submit", response_model=Result[bool]) async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()): @@ -237,35 +275,53 @@ async def chart_editor_submit(chart_edit_context: ChatChartEditContext = Body()) dashboard_data_loader: DashboardDataLoader = DashboardDataLoader() db_conn = CFG.LOCAL_DB_MANAGE.get_connect(chart_edit_context.db_name) - edit_round = max(history_messages, key=lambda x: x['chat_order']) + edit_round = max(history_messages, key=lambda x: x["chat_order"]) if edit_round: try: for element in edit_round["messages"]: if element["type"] == "view": - view_data: dict = json.loads(element["data"]["content"]); + view_data: dict = json.loads(element["data"]["content"]) charts: List = view_data.get("charts") - find_chart = list(filter(lambda x: x['chart_name'] == chart_edit_context.chart_title, charts))[ - 0] + find_chart = list( + filter( + lambda x: x["chart_name"] + == chart_edit_context.chart_title, + charts, + ) + )[0] if chart_edit_context.new_chart_type: - find_chart['chart_type'] = chart_edit_context.new_chart_type + find_chart["chart_type"] = chart_edit_context.new_chart_type if chart_edit_context.new_comment: - find_chart['chart_desc'] = chart_edit_context.new_comment + find_chart["chart_desc"] = chart_edit_context.new_comment - field_names, chart_values = dashboard_data_loader.get_chart_values_by_conn(db_conn, - chart_edit_context.new_sql) - find_chart['chart_sql'] = chart_edit_context.new_sql - find_chart['values'] = [value.dict() for value in chart_values] - find_chart['column_name'] = field_names + ( + field_names, + chart_values, + ) = dashboard_data_loader.get_chart_values_by_conn( + db_conn, chart_edit_context.new_sql + ) + find_chart["chart_sql"] = chart_edit_context.new_sql + find_chart["values"] = [value.dict() for value in chart_values] + find_chart["column_name"] = field_names - element["data"]["content"] = json.dumps(view_data, ensure_ascii=False) + element["data"]["content"] = json.dumps( + view_data, ensure_ascii=False + ) if element["type"] == "ai": ai_resp: dict = json.loads(element["data"]["content"]) - edit_item = list(filter(lambda x: x['title'] == chart_edit_context.chart_title, ai_resp))[0] + edit_item = list( + filter( + lambda x: x["title"] == chart_edit_context.chart_title, + ai_resp, + ) + )[0] edit_item["sql"] = chart_edit_context.new_sql edit_item["showcase"] = chart_edit_context.new_chart_type edit_item["thoughts"] = chart_edit_context.new_comment - element["data"]["content"] = json.dumps(ai_resp, ensure_ascii=False) + element["data"]["content"] = json.dumps( + ai_resp, ensure_ascii=False + ) except Exception as e: logger.error(f"edit chart exception!{str(e)}", e) return Result.faild(msg=f"Edit chart exception!{str(e)}") diff --git a/pilot/openapi/api_v1/editor/sql_editor.py b/pilot/openapi/api_v1/editor/sql_editor.py index 6eeb6eeac..2cebe7f92 100644 --- a/pilot/openapi/api_v1/editor/sql_editor.py +++ b/pilot/openapi/api_v1/editor/sql_editor.py @@ -9,7 +9,7 @@ class DataNode(BaseModel): type: str = "" default_value: str = None - can_null: str = 'YES' + can_null: str = "YES" comment: str = None children: List = [] diff --git a/pilot/openapi/editor_view_model.py b/pilot/openapi/editor_view_model.py index ad35cdc3e..bbff0227f 100644 --- a/pilot/openapi/editor_view_model.py +++ b/pilot/openapi/editor_view_model.py @@ -10,11 +10,13 @@ class DbField(BaseModel): default_value: str = "" comment: str = "" + class DbTable(BaseModel): table_name: str comment: str colunm: List[DbField] + class ChatDbRounds(BaseModel): round: int db_name: str @@ -61,4 +63,3 @@ class ChatSqlEditContext(BaseModel): new_sql: str new_speak: str = "" - diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 58bc7d12f..5fe06916b 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -123,11 +123,7 @@ class BaseOutputParser(ABC): ai_response = ai_response.replace("\*", "*") ai_response = ai_response.replace("\t", "") - ai_response = ( - ai_response.strip() - .replace("\\n", " ") - .replace("\n", " ") - ) + ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ") print("un_stream ai response:", ai_response) return ai_response else: @@ -209,9 +205,9 @@ class BaseOutputParser(ABC): cleaned_output = self.__extract_json(cleaned_output) cleaned_output = ( cleaned_output.strip() - .replace("\\n", " ") - .replace("\n", " ") - .replace("\\", " ") + .replace("\\n", " ") + .replace("\n", " ") + .replace("\\", " ") ) cleaned_output = self.__illegal_json_ends(cleaned_output) return cleaned_output diff --git a/pilot/scene/base.py b/pilot/scene/base.py index eb9113e77..162759e3c 100644 --- a/pilot/scene/base.py +++ b/pilot/scene/base.py @@ -12,7 +12,6 @@ class Scene: is_inner: bool = False, show_disable=False, prepare_scene_code: str = None, - ): self.code = code self.name = name @@ -22,38 +21,39 @@ class Scene: self.show_disable = show_disable self.prepare_scene_code = prepare_scene_code + class ChatScene(Enum): ChatWithDbExecute = Scene( - code = "chat_with_db_execute", - name = "Chat Data", - describe = "Dialogue with your private data through natural language.", - param_types = ["DB Select"], + code="chat_with_db_execute", + name="Chat Data", + describe="Dialogue with your private data through natural language.", + param_types=["DB Select"], ) ExcelLearning = Scene( - code = "excel_learning", - name = "Excel Learning", - describe = "Analyze and summarize your excel files.", - is_inner = True, + code="excel_learning", + name="Excel Learning", + describe="Analyze and summarize your excel files.", + is_inner=True, ) ChatExcel = Scene( - code = "chat_excel", - name = "Chat Excel", - describe = "Dialogue with your excel, use natural language.", + code="chat_excel", + name="Chat Excel", + describe="Dialogue with your excel, use natural language.", param_types=["File Select"], - prepare_scene_code="excel_learning" + prepare_scene_code="excel_learning", ) ChatWithDbQA = Scene( - code = "chat_with_db_qa", - name = "Chat DB", - describe = "Have a Professional Conversation with Metadata.", - param_types = ["DB Select"], + code="chat_with_db_qa", + name="Chat DB", + describe="Have a Professional Conversation with Metadata.", + param_types=["DB Select"], ) ChatExecution = Scene( - code = "chat_execution", - name = "Use Plugin", - describe = "Use tools through dialogue to accomplish your goals.", - param_types = ["Plugin Select"], + code="chat_execution", + name="Use Plugin", + describe="Use tools through dialogue to accomplish your goals.", + param_types=["Plugin Select"], ) InnerChatDBSummary = Scene( @@ -78,7 +78,7 @@ class ChatScene(Enum): @staticmethod def of_mode(mode): - return [x for x in ChatScene._value_ if x.code == mode][0] + return [x for x in ChatScene._value_ if x.code == mode][0] @staticmethod def is_valid_mode(mode): @@ -100,4 +100,4 @@ class ChatScene(Enum): return self._value_.show_disable def is_inner(self): - return self._value_.is_inner \ No newline at end of file + return self._value_.is_inner diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py index 6a845154e..fe515afb5 100644 --- a/pilot/scene/base_chat.py +++ b/pilot/scene/base_chat.py @@ -60,11 +60,7 @@ class BaseChat(ABC): arbitrary_types_allowed = True def __init__( - self, - chat_mode, - chat_session_id, - current_user_input, - select_param: Any = None + self, chat_mode, chat_session_id, current_user_input, select_param: Any = None ): self.chat_session_id = chat_session_id self.chat_mode = chat_mode @@ -72,7 +68,6 @@ class BaseChat(ABC): self.llm_model = CFG.LLM_MODEL self.llm_echo = False - ### load prompt template # self.prompt_template: PromptTemplate = CFG.prompt_templates[ # self.chat_mode.value() @@ -182,6 +177,7 @@ class BaseChat(ABC): return response else: from pilot.server.llmserver import worker + return worker.generate_stream_gate(payload) except Exception as e: print(traceback.format_exc()) @@ -235,7 +231,9 @@ class BaseChat(ABC): ### llm speaker speak_to_user = self.get_llm_speak(prompt_define_response) - view_message = self.prompt_template.output_parser.parse_view_response(speak_to_user, result) + view_message = self.prompt_template.output_parser.parse_view_response( + speak_to_user, result + ) self.current_message.add_view_message(view_message) except Exception as e: print(traceback.format_exc()) @@ -253,10 +251,8 @@ class BaseChat(ABC): else: return self.nostream_call() - def prepare(self): - pass - + pass def generate_llm_text(self) -> str: warnings.warn("This method is deprecated - please use `generate_llm_messages`.") @@ -363,9 +359,7 @@ class BaseChat(ABC): ) if len(self.history_message) > self.chat_retention_rounds: for first_message in self.history_message[0]["messages"]: - if not first_message["type"] in [ - ModelMessageRoleType.VIEW - ]: + if not first_message["type"] in [ModelMessageRoleType.VIEW]: message_type = first_message["type"] message_content = first_message["data"]["content"] history_text += ( @@ -394,7 +388,9 @@ class BaseChat(ABC): + self.prompt_template.sep ) history_messages.append( - ModelMessage(role=message_type, content=message_content) + ModelMessage( + role=message_type, content=message_content + ) ) else: diff --git a/pilot/scene/base_message.py b/pilot/scene/base_message.py index 168fb2bb9..75c0aad97 100644 --- a/pilot/scene/base_message.py +++ b/pilot/scene/base_message.py @@ -63,6 +63,7 @@ class AIMessage(BaseMessage): class ViewMessage(BaseMessage): """Type of message that is spoken by the AI.""" + example: bool = False @property @@ -73,6 +74,7 @@ class ViewMessage(BaseMessage): class SystemMessage(BaseMessage): """Type of message that is a system message.""" + @property def type(self) -> str: """Type of the message, used for serialization.""" diff --git a/pilot/scene/chat_dashboard/chat.py b/pilot/scene/chat_dashboard/chat.py index 29f962dd8..faa255887 100644 --- a/pilot/scene/chat_dashboard/chat.py +++ b/pilot/scene/chat_dashboard/chat.py @@ -21,9 +21,15 @@ class ChatDashboard(BaseChat): report_name: str """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input, select_param:str = "", report_name:str="report"): + def __init__( + self, + chat_session_id, + user_input, + select_param: str = "", + report_name: str = "report", + ): """ """ - self.db_name=select_param + self.db_name = select_param super().__init__( chat_mode=ChatScene.ChatDashboard, chat_session_id=chat_session_id, @@ -80,7 +86,9 @@ class ChatDashboard(BaseChat): dashboard_data_loader = DashboardDataLoader() for chart_item in prompt_response: try: - field_names, values = dashboard_data_loader.get_chart_values_by_conn(self.database, chart_item.sql) + field_names, values = dashboard_data_loader.get_chart_values_by_conn( + self.database, chart_item.sql + ) chart_datas.append( ChartData( chart_uid=str(uuid.uuid1()), @@ -101,5 +109,3 @@ class ChatDashboard(BaseChat): template_introduce=None, charts=chart_datas, ) - - diff --git a/pilot/scene/chat_dashboard/data_loader.py b/pilot/scene/chat_dashboard/data_loader.py index 5a7b78bf4..945bf76cb 100644 --- a/pilot/scene/chat_dashboard/data_loader.py +++ b/pilot/scene/chat_dashboard/data_loader.py @@ -11,15 +11,14 @@ logger = build_logger("dashboard_data", LOGDIR + "dashboard_data.log") class DashboardDataLoader: - def get_sql_value(self, db_conn, chart_sql: str): - return db_conn.query_ex(chart_sql) + return db_conn.query_ex(chart_sql) - def get_chart_values_by_conn(self, db_conn, chart_sql: str) : - field_names, datas = db_conn.query_ex(chart_sql) - return self.get_chart_values_by_data(field_names, datas, chart_sql) + def get_chart_values_by_conn(self, db_conn, chart_sql: str): + field_names, datas = db_conn.query_ex(chart_sql) + return self.get_chart_values_by_data(field_names, datas, chart_sql) - def get_chart_values_by_data(self, field_names, datas, chart_sql: str) : + def get_chart_values_by_data(self, field_names, datas, chart_sql: str): logger.info(f"get_chart_values_by_conn:{chart_sql}") try: values: List[ValueItem] = [] @@ -57,7 +56,7 @@ class DashboardDataLoader: logger.debug("Prepare Chart Data Faild!" + str(e)) raise ValueError("Prepare Chart Data Faild!") - def get_chart_values_by_db(self, db_name: str, chart_sql: str) : + def get_chart_values_by_db(self, db_name: str, chart_sql: str): logger.info(f"get_chart_values_by_db:{db_name},{chart_sql}") db_conn = CFG.LOCAL_DB_MANAGE.get_connect(db_name) return self.get_chart_values_by_conn(db_conn, chart_sql) diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py index ca23d3914..806fada18 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py @@ -27,6 +27,7 @@ CFG = Config() class ChatExcel(BaseChat): chat_scene: str = ChatScene.ChatExcel.value() chat_retention_rounds = 1 + def __init__(self, chat_session_id, user_input, select_param: str = ""): chat_mode = ChatScene.ChatExcel @@ -34,9 +35,11 @@ class ChatExcel(BaseChat): if has_path(select_param): self.excel_reader = ExcelReader(select_param) else: - self.excel_reader = ExcelReader(os.path.join( - KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param - )) + self.excel_reader = ExcelReader( + os.path.join( + KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), select_param + ) + ) super().__init__( chat_mode=chat_mode, @@ -70,12 +73,9 @@ class ChatExcel(BaseChat): ] return "\n".join(f"{i+1}. {item}" for i, item in enumerate(command_strings)) - def generate_input_values(self): - - input_values = { - "user_input": self.current_user_input, + "user_input": self.current_user_input, "table_name": self.excel_reader.table_name, "disply_type": self._generate_numbered_list(), } @@ -87,23 +87,21 @@ class ChatExcel(BaseChat): return None chat_param = { "chat_session_id": self.chat_session_id, - "user_input": "[" + self.excel_reader.excel_file_name +"]" + " Analysis!", + "user_input": "[" + self.excel_reader.excel_file_name + "]" + " Analysis!", "parent_mode": self.chat_mode, - "select_param":self.excel_reader.excel_file_name, - "excel_reader": self.excel_reader + "select_param": self.excel_reader.excel_file_name, + "excel_reader": self.excel_reader, } learn_chat = ExcelLearning(**chat_param) result = learn_chat.nostream_call() return result - def do_action(self, prompt_response): print(f"do_action:{prompt_response}") # colunms, datas = self.excel_reader.run(prompt_response.sql) - param= { + param = { "speak": prompt_response.thoughts, - "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql) + "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql), } return CFG.command_disply.call(prompt_response.display, **param) - diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py index 8ec352721..7b4b7f1dc 100644 --- a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py +++ b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py @@ -2,7 +2,9 @@ import json from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene -from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ChatExcelOutputParser +from pilot.scene.chat_data.chat_excel.excel_analyze.out_parser import ( + ChatExcelOutputParser, +) from pilot.common.schema import SeparatorStyle CFG = Config() @@ -39,9 +41,9 @@ SQL中需要使用的表名是: {table_name} """ RESPONSE_FORMAT_SIMPLE = { - "sql": "analysis SQL", - "thoughts": "Current thinking and value of data analysis", - "display": "display type name" + "sql": "analysis SQL", + "thoughts": "Current thinking and value of data analysis", + "display": "display type name", } _DEFAULT_TEMPLATE = ( @@ -67,9 +69,8 @@ prompt = PromptTemplate( output_parser=ChatExcelOutputParser( sep=PROMPT_SEP, is_stream_out=PROMPT_NEED_NEED_STREAM_OUT ), - need_historical_messages = True, + need_historical_messages=True, # example_selector=sql_data_example, temperature=PROMPT_TEMPERATURE, ) CFG.prompt_template_registry.register(prompt, is_default=True) - diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py index 8e5028f79..e49b2f47d 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py @@ -23,7 +23,14 @@ CFG = Config() class ExcelLearning(BaseChat): chat_scene: str = ChatScene.ExcelLearning.value() - def __init__(self, chat_session_id, user_input, parent_mode: Any=None, select_param:str=None, excel_reader:Any=None): + def __init__( + self, + chat_session_id, + user_input, + parent_mode: Any = None, + select_param: str = None, + excel_reader: Any = None, + ): chat_mode = ChatScene.ExcelLearning """ """ self.excel_file_path = select_param @@ -31,20 +38,19 @@ class ExcelLearning(BaseChat): super().__init__( chat_mode=chat_mode, chat_session_id=chat_session_id, - current_user_input = user_input, + current_user_input=user_input, select_param=select_param, ) if parent_mode: self.current_message.chat_mode = parent_mode.value() def generate_input_values(self): - colunms, datas = self.excel_reader.get_sample_data() datas.insert(0, colunms) input_values = { - "data_example": json.dumps(self.excel_reader.get_sample_data(), cls=DateTimeEncoder), + "data_example": json.dumps( + self.excel_reader.get_sample_data(), cls=DateTimeEncoder + ), } return input_values - - diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py index 5bb9ba2d3..3ded0c792 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py @@ -37,23 +37,23 @@ class LearningExcelOutputParser(BaseOutputParser): plans = response[key] return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans) - - def parse_view_response(self, speak, data) -> str: ### tool out data to table view html_title = f"### **数据简介**\n{data.desciption} " html_colunms = f"### **数据结构**\n" column_index = 0 for item in data.clounms: - column_index +=1 + column_index += 1 keys = item.keys() for key in keys: - html_colunms = html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" + html_colunms = ( + html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n" + ) html_plans = f"### **分析计划**\n" index = 0 for item in data.plans: - index +=1 + index += 1 html_plans = html_plans + f"{item} \n" html = f"""{html_title}\n{html_colunms}\n{html_plans}""" return html diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py index d4ee0d6b1..23a3696ba 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py @@ -2,7 +2,9 @@ import json from pilot.prompts.prompt_new import PromptTemplate from pilot.configs.config import Config from pilot.scene.base import ChatScene -from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import LearningExcelOutputParser +from pilot.scene.chat_data.chat_excel.excel_learning.out_parser import ( + LearningExcelOutputParser, +) from pilot.common.schema import SeparatorStyle CFG = Config() @@ -29,7 +31,7 @@ _DEFAULT_TEMPLATE_ZH = """ {response} """ -RESPONSE_FORMAT_SIMPLE = { +RESPONSE_FORMAT_SIMPLE = { "DataAnalysis": "数据内容分析总结", "ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}], "AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"], @@ -63,5 +65,3 @@ prompt = PromptTemplate( temperature=PROMPT_TEMPERATURE, ) CFG.prompt_template_registry.register(prompt, is_default=True) - - diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/test.py b/pilot/scene/chat_data/chat_excel/excel_learning/test.py index 4fe66ae06..ecdd03867 100644 --- a/pilot/scene/chat_data/chat_excel/excel_learning/test.py +++ b/pilot/scene/chat_data/chat_excel/excel_learning/test.py @@ -7,7 +7,7 @@ import seaborn as sns import matplotlib.pyplot as plt import time from fsspec import filesystem -import spatial +import spatial from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader @@ -18,22 +18,35 @@ if __name__ == "__main__": excel_reader = ExcelReader("/Users/tuyang.yhj/Downloads/example.xlsx") # colunms, datas = excel_reader.run( "SELECT CONCAT(Year, '-', Quarter) AS QuarterYear, SUM(Sales) AS TotalSales FROM example GROUP BY QuarterYear ORDER BY QuarterYear") - colunms, datas = excel_reader.run( """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """) - df = excel_reader.get_df_by_sql_ex("SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;") + colunms, datas = excel_reader.run( + """ SELECT Year, SUM(Sales) AS Total_Sales FROM example GROUP BY Year ORDER BY Year; """ + ) + df = excel_reader.get_df_by_sql_ex( + "SELECT Country, SUM(Profit) AS Total_Profit FROM example GROUP BY Country;" + ) columns = df.columns.tolist() plt.rcParams["font.family"] = ["sans-serif"] rc = {"font.sans-serif": "SimHei", "axes.unicode_minus": False} - sns.set_style(rc={'font.sans-serif': "Microsoft Yahei"}) + sns.set_style(rc={"font.sans-serif": "Microsoft Yahei"}) sns.set(context="notebook", style="ticks", color_codes=True, rc=rc) sns.set_palette("Set3") # 设置颜色主题 # fig, ax = plt.pie(df[columns[1]], labels=df[columns[0]], autopct='%1.1f%%', startangle=90) fig, ax = plt.subplots(figsize=(8, 5), dpi=100) plt.subplots_adjust(top=0.9) - ax = df.plot(kind='pie', y=columns[1], ax=ax, labels=df[columns[0]].values, startangle=90, autopct='%1.1f%%') + ax = df.plot( + kind="pie", + y=columns[1], + ax=ax, + labels=df[columns[0]].values, + startangle=90, + autopct="%1.1f%%", + ) # 手动设置 labels 的位置和大小 - ax.legend(loc='center left', bbox_to_anchor=(-1, 0.5, 0,0), labels=None, fontsize=10) - plt.axis('equal') # 使饼图为正圆形 + ax.legend( + loc="center left", bbox_to_anchor=(-1, 0.5, 0, 0), labels=None, fontsize=10 + ) + plt.axis("equal") # 使饼图为正圆形 plt.show() # # diff --git a/pilot/scene/chat_data/chat_excel/excel_reader.py b/pilot/scene/chat_data/chat_excel/excel_reader.py index d5afd1a9c..ff4759fa2 100644 --- a/pilot/scene/chat_data/chat_excel/excel_reader.py +++ b/pilot/scene/chat_data/chat_excel/excel_reader.py @@ -7,11 +7,13 @@ import numpy as np from pilot.common.pd_utils import csv_colunm_foramt -def excel_colunm_format(old_name:str)->str: + +def excel_colunm_format(old_name: str) -> str: new_column = old_name.strip() new_column = new_column.replace(" ", "_") return new_column + def add_quotes(sql, column_names=[]): sql = sql.replace("`", "") parsed = sqlparse.parse(sql) @@ -20,44 +22,51 @@ def add_quotes(sql, column_names=[]): deep_quotes(token, column_names) return str(parsed[0]) + def deep_quotes(token, column_names=[]): - if hasattr(token, "tokens") : + if hasattr(token, "tokens"): for token_child in token.tokens: deep_quotes(token_child, column_names) else: if token.ttype == sqlparse.tokens.Name: - if len(column_names) >0: + if len(column_names) > 0: if token.value in column_names: token.value = f'"{token.value.replace("`", "")}"' else: token.value = f'"{token.value.replace("`", "")}"' + def is_chinese(string): # 使用正则表达式匹配中文字符 - pattern = re.compile(r'[一-龥]') + pattern = re.compile(r"[一-龥]") match = re.search(pattern, string) return match is not None + class ExcelReader: - def __init__(self, file_path): - file_name = os.path.basename(file_path) file_name_without_extension = os.path.splitext(file_name)[0] self.excel_file_name = file_name self.extension = os.path.splitext(file_name)[1] # read excel file - if file_path.endswith('.xlsx') or file_path.endswith('.xls'): + if file_path.endswith(".xlsx") or file_path.endswith(".xls"): df_tmp = pd.read_excel(file_path) - self.df = pd.read_excel(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) - elif file_path.endswith('.csv'): + self.df = pd.read_excel( + file_path, + converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}, + ) + elif file_path.endswith(".csv"): df_tmp = pd.read_csv(file_path) - self.df = pd.read_csv(file_path, converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}) + self.df = pd.read_csv( + file_path, + converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])}, + ) else: raise ValueError("Unsupported file format.") - self.df.replace('', np.nan, inplace=True) + self.df.replace("", np.nan, inplace=True) self.columns_map = {} for column_name in df_tmp.columns: self.columns_map.update({column_name: excel_colunm_format(column_name)}) @@ -66,18 +75,17 @@ class ExcelReader: except Exception as e: print("transfor column error!" + column_name) - self.df = self.df.rename(columns=lambda x: x.strip().replace(' ', '_')) + self.df = self.df.rename(columns=lambda x: x.strip().replace(" ", "_")) # connect DuckDB - self.db = duckdb.connect(database=':memory:', read_only=False) - + self.db = duckdb.connect(database=":memory:", read_only=False) self.table_name = file_name_without_extension # write data in duckdb self.db.register(self.table_name, self.df) def run(self, sql): - if f'"{self.table_name}"' not in sql: + if f'"{self.table_name}"' not in sql: sql = sql.replace(self.table_name, f'"{self.table_name}"') sql = add_quotes(sql, self.columns_map.values()) print(f"excute sql:{sql}") @@ -92,5 +100,4 @@ class ExcelReader: return pd.DataFrame(values, columns=colunms) def get_sample_data(self): - return self.run(f'SELECT * FROM {self.table_name} LIMIT 5;') - + return self.run(f"SELECT * FROM {self.table_name} LIMIT 5;") diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py index d6c2b0536..595e65e2e 100644 --- a/pilot/scene/chat_db/auto_execute/chat.py +++ b/pilot/scene/chat_db/auto_execute/chat.py @@ -21,7 +21,7 @@ class ChatWithDbAutoExecute(BaseChat): """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input, select_param:str = ""): + def __init__(self, chat_session_id, user_input, select_param: str = ""): chat_mode = ChatScene.ChatWithDbExecute self.db_name = select_param """ """ @@ -47,7 +47,9 @@ class ChatWithDbAutoExecute(BaseChat): client = DBSummaryClient() try: table_infos = client.get_db_summary( - dbname=self.db_name, query=self.current_user_input, topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE + dbname=self.db_name, + query=self.current_user_input, + topk=CFG.KNOWLEDGE_SEARCH_TOP_SIZE, ) except Exception as e: print("db summary find error!" + str(e)) @@ -65,4 +67,4 @@ class ChatWithDbAutoExecute(BaseChat): def do_action(self, prompt_response): print(f"do_action:{prompt_response}") - return self.database.run( prompt_response.sql) + return self.database.run(prompt_response.sql) diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py index ce20c6987..f19903fc8 100644 --- a/pilot/scene/chat_db/auto_execute/out_parser.py +++ b/pilot/scene/chat_db/auto_execute/out_parser.py @@ -8,6 +8,7 @@ from pilot.out_parser.base import BaseOutputParser, T from pilot.configs.model_config import LOGDIR from pilot.configs.config import Config from pilot.scene.chat_db.data_loader import DbDataLoader + CFG = Config() @@ -49,6 +50,4 @@ class DbChatOutputParser(BaseOutputParser): view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") return view_text else: - return data_loader.get_table_view_by_conn(data, speak) - - + return data_loader.get_table_view_by_conn(data, speak) diff --git a/pilot/scene/chat_db/data_loader.py b/pilot/scene/chat_db/data_loader.py index b41640ceb..b1f9ea93d 100644 --- a/pilot/scene/chat_db/data_loader.py +++ b/pilot/scene/chat_db/data_loader.py @@ -1,8 +1,7 @@ import pandas as pd + class DbDataLoader: - - def get_table_view_by_conn(self, data, speak): ### tool out data to table view if len(data) <= 1: @@ -12,4 +11,4 @@ class DbDataLoader: table_str = "".join(html_table.split()) html = f"""
{table_str}
""" view_text = f"##### {str(speak)}" + "\n" + html.replace("\n", " ") - return view_text \ No newline at end of file + return view_text diff --git a/pilot/scene/chat_db/professional_qa/chat.py b/pilot/scene/chat_db/professional_qa/chat.py index 067022e1b..18b587acf 100644 --- a/pilot/scene/chat_db/professional_qa/chat.py +++ b/pilot/scene/chat_db/professional_qa/chat.py @@ -19,7 +19,7 @@ class ChatWithDbQA(BaseChat): """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input, select_param:str = ""): + def __init__(self, chat_session_id, user_input, select_param: str = ""): """ """ self.db_name = select_param super().__init__( diff --git a/pilot/scene/chat_execution/chat.py b/pilot/scene/chat_execution/chat.py index 51813ce60..43727f613 100644 --- a/pilot/scene/chat_execution/chat.py +++ b/pilot/scene/chat_execution/chat.py @@ -20,12 +20,7 @@ class ChatWithPlugin(BaseChat): plugins_prompt_generator: PluginPromptGenerator select_plugin: str = None - def __init__( - self, - chat_session_id, - user_input, - select_param: str = None - ): + def __init__(self, chat_session_id, user_input, select_param: str = None): self.plugin_selector = select_param super().__init__( chat_mode=ChatScene.ChatExecution, diff --git a/pilot/scene/chat_knowledge/v1/chat.py b/pilot/scene/chat_knowledge/v1/chat.py index 316e38608..e4213f447 100644 --- a/pilot/scene/chat_knowledge/v1/chat.py +++ b/pilot/scene/chat_knowledge/v1/chat.py @@ -30,7 +30,7 @@ class ChatKnowledge(BaseChat): """Number of results to return from the query""" - def __init__(self, chat_session_id, user_input, select_param: str = None): + def __init__(self, chat_session_id, user_input, select_param: str = None): """ """ self.knowledge_space = select_param super().__init__( diff --git a/pilot/scene/chat_normal/chat.py b/pilot/scene/chat_normal/chat.py index cd3cae7bc..531319366 100644 --- a/pilot/scene/chat_normal/chat.py +++ b/pilot/scene/chat_normal/chat.py @@ -30,7 +30,6 @@ class ChatNormal(BaseChat): input_values = {"input": self.current_user_input} return input_values - @property def chat_type(self) -> str: return ChatScene.ChatNormal.value diff --git a/pilot/scene/message.py b/pilot/scene/message.py index ea74bae2f..73a141988 100644 --- a/pilot/scene/message.py +++ b/pilot/scene/message.py @@ -117,12 +117,10 @@ def _conversation_to_dic(once: OnceConversation) -> dict: "tokens": once.tokens if once.tokens else 0, "messages": messages_to_dict(once.messages), "param_type": once.param_type, - "param_value": once.param_value + "param_value": once.param_value, } - - def conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]: return [_conversation_to_dic(m) for m in conversations] diff --git a/pilot/server/base.py b/pilot/server/base.py index 22cbef6bc..3dfaa1c7e 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -64,16 +64,12 @@ def server_init(args): cfg.command_registry = command_registry - command_disply_commands = [ "pilot.commands.disply_type.show_chart_gen", "pilot.commands.disply_type.show_table_gen", "pilot.commands.disply_type.show_text_gen", ] - command_disply_registry = CommandRegistry() + command_disply_registry = CommandRegistry() for command in command_disply_commands: command_disply_registry.import_commands(command) cfg.command_disply = command_disply_registry - - - diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index cac686684..6c0753fa2 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -81,7 +81,9 @@ app.include_router(knowledge_router) # app.include_router(api_editor_route_v1) os.makedirs(static_message_img_path, exist_ok=True) -app.mount("/images", StaticFiles(directory=static_message_img_path, html=True), name="static2") +app.mount( + "/images", StaticFiles(directory=static_message_img_path, html=True), name="static2" +) app.mount("/_next/static", StaticFiles(directory=static_file_path + "/_next/static")) app.mount("/", StaticFiles(directory=static_file_path, html=True), name="static") diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index ac23953b5..99b992698 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -19,7 +19,6 @@ from pilot.utils import build_logger logger = build_logger("db_summary", LOGDIR + "db_summary.log") - CFG = Config() chat_factory = ChatFactory() @@ -145,7 +144,9 @@ class DBSummaryClient: try: self.db_summary_embedding(item["db_name"], item["db_type"]) except Exception as e: - logger.warn(f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e) + logger.warn( + f'{item["db_name"]}, {item["db_type"]} summary error!{str(e)}', e + ) def init_db_profile(self, db_summary_client, dbname, embeddings): profile_store_config = { diff --git a/pilot/utils.py b/pilot/utils.py index c44a4ea2d..d05387056 100644 --- a/pilot/utils.py +++ b/pilot/utils.py @@ -75,7 +75,6 @@ def build_logger(logger_name, logger_filename): item.addHandler(handler) logging.basicConfig(level=logging.INFO) - # Get logger logger = logging.getLogger(logger_name) logger.setLevel(logging.INFO)