From def77e66fe57f0917d626028f67e6e393fc7e9ce Mon Sep 17 00:00:00 2001
From: yhjun1026 <460342015@qq.com>
Date: Fri, 13 Oct 2023 15:46:07 +0800
Subject: [PATCH] feat(Agent): ChatAgent And AgentHub
1.LLM TongYiQianWen WenXinYiYan ZhiPu support
---
.../agent/commands/command_mange.py | 36 ++++--
.../commands/disply_type/show_chart_gen.py | 113 ++++++++++--------
.../commands/disply_type/show_table_gen.py | 6 +-
pilot/common/string_utils.py | 11 ++
pilot/model/proxy/llms/tongyi.py | 49 ++++----
pilot/model/proxy/llms/wenxin.py | 35 ++++--
pilot/model/proxy/llms/zhipu.py | 28 ++++-
.../chat_excel/excel_analyze/prompt.py | 15 +--
.../chat_data/chat_excel/excel_reader.py | 66 +++++-----
pilot/server/static/404.html | 2 +-
pilot/server/static/404/index.html | 2 +-
.../_buildManifest.js | 0
.../_ssgManifest.js | 0
.../static/chat/[scene]/[id]/index.html | 2 +-
pilot/server/static/database/index.html | 2 +-
.../datastores/documents/chunklist/index.html | 2 +-
.../static/datastores/documents/index.html | 2 +-
pilot/server/static/datastores/index.html | 2 +-
pilot/server/static/index.html | 2 +-
pilot/server/static/prompt/index.html | 2 +-
setup.py | 3 +
21 files changed, 234 insertions(+), 146 deletions(-)
rename pilot/server/static/_next/static/{TTwLTfqU5PAKOW_pnzHId => g-7DkRI9SHIZ9nvhXpw7n}/_buildManifest.js (100%)
rename pilot/server/static/_next/static/{TTwLTfqU5PAKOW_pnzHId => g-7DkRI9SHIZ9nvhXpw7n}/_ssgManifest.js (100%)
diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index 48866ae22..951b60639 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -87,6 +87,11 @@ class CommandRegistry:
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
+ def is_valid_command(self, name:str)-> bool:
+ if name not in self.commands:
+ return False
+ else:
+ return True
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
@@ -229,6 +234,18 @@ class ApiCall:
return True
return False
+ def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
+ error_md_tags = ["```", "```python", "```xml", "```json", "```markdown"]
+ if include_end == False:
+ md_tag_end = ""
+ else:
+ md_tag_end = "```"
+ for tag in error_md_tags:
+ all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
+ all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context)
+ all_context = all_context.replace(tag + api_context, api_context)
+ return all_context
+
def api_view_context(self, all_context: str, display_mode: bool = False):
error_mk_tags = ["```", "```python", "```xml"]
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
@@ -237,22 +254,20 @@ class ApiCall:
if api_status is not None:
if display_mode:
if api_status.api_result:
- for tag in error_mk_tags:
- all_context = all_context.replace(tag + api_context + "```", api_context)
+ all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, api_status.api_result)
else:
if api_status.status == Status.FAILED.value:
+ all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f"""\nERROR!{api_status.err_msg}\n """)
else:
cost = (api_status.end_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
- for tag in error_mk_tags:
- all_context = all_context.replace(tag + api_context + "```", api_context)
+ all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n')
else:
- for tag in error_mk_tags:
- all_context = all_context.replace(tag + api_context + "```", api_context)
+ all_context = self.__deal_error_md_tags(all_context, api_context, False)
all_context = all_context.replace(api_context, self.to_view_text(api_status))
else:
@@ -260,13 +275,13 @@ class ApiCall:
now_time = datetime.now().timestamp() * 1000
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
+ for tag in error_mk_tags:
+ all_context = all_context.replace(tag + api_context , api_context)
all_context = all_context.replace(api_context, f'\nWaiting...{cost_str}S\n')
return all_context
def update_from_context(self, all_context):
- logging.info(f"from_context:{all_context}")
-
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
for api_index, api_context in api_context_map.items():
api_context = api_context.replace("\\n", "").replace("\n", "")
@@ -308,7 +323,6 @@ class ApiCall:
return result.decode("utf-8")
def run(self, llm_text):
- print(f"stream_plugin_call:{llm_text}")
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
@@ -327,8 +341,6 @@ class ApiCall:
return self.api_view_context(llm_text)
def run_display_sql(self, llm_text, sql_run_func):
- print(f"get_display_sql:{llm_text}")
-
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
@@ -343,7 +355,7 @@ class ApiCall:
param = {
"df": sql_run_func(sql),
}
- if self.display_registry.get_command(value.name):
+ if self.display_registry.is_valid_command(value.name):
value.api_result = self.display_registry.call(value.name, **param)
else:
value.api_result = self.display_registry.call("response_table", **param)
diff --git a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
index 793055257..7807f3818 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_chart_gen.py
@@ -12,7 +12,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
-
+from pilot.common.string_utils import is_scientific_notation
import logging
logger = logging.getLogger(__name__)
@@ -88,6 +88,14 @@ def zh_font_set():
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
+def format_axis(value, pos):
+ # 判断是否为数字
+ if is_scientific_notation(value):
+ # 判断是否需要进行非科学计数法格式化
+
+ return '{:.2f}'.format(value)
+ return value
+
@command(
"response_line_chart",
@@ -98,58 +106,61 @@ def response_line_chart( df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data!")
+ try:
+ # set font
+ # zh_font_set()
+ font_names = [
+ "Heiti TC",
+ "Songti SC",
+ "STHeiti Light",
+ "Microsoft YaHei",
+ "SimSun",
+ "SimHei",
+ "KaiTi",
+ ]
+ fm = FontManager()
+ mat_fonts = set(f.name for f in fm.ttflist)
+ can_use_fonts = []
+ for font_name in font_names:
+ if font_name in mat_fonts:
+ can_use_fonts.append(font_name)
+ if len(can_use_fonts) > 0:
+ plt.rcParams["font.sans-serif"] = can_use_fonts
- # set font
- # zh_font_set()
- font_names = [
- "Heiti TC",
- "Songti SC",
- "STHeiti Light",
- "Microsoft YaHei",
- "SimSun",
- "SimHei",
- "KaiTi",
- ]
- fm = FontManager()
- mat_fonts = set(f.name for f in fm.ttflist)
- can_use_fonts = []
- for font_name in font_names:
- if font_name in mat_fonts:
- can_use_fonts.append(font_name)
- if len(can_use_fonts) > 0:
- plt.rcParams["font.sans-serif"] = can_use_fonts
+ rc = {"font.sans-serif": can_use_fonts}
+ plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
- rc = {"font.sans-serif": can_use_fonts}
- plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
+ sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
+ sns.set_palette("Set3") # 设置颜色主题
+ sns.set_style("dark")
+ sns.color_palette("hls", 10)
+ sns.hls_palette(8, l=0.5, s=0.7)
+ sns.set(context="notebook", style="ticks", rc=rc)
- sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
- sns.set_palette("Set3") # 设置颜色主题
- sns.set_style("dark")
- sns.color_palette("hls", 10)
- sns.hls_palette(8, l=0.5, s=0.7)
- sns.set(context="notebook", style="ticks", rc=rc)
+ fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
+ x, y, non_num_columns, num_colmns = data_pre_classification(df)
+ # ## 复杂折线图实现
+ if len(num_colmns) > 0:
+ num_colmns.append(y)
+ df_melted = pd.melt(
+ df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
+ )
+ sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
+ else:
+ sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
- fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
- x, y, non_num_columns, num_colmns = data_pre_classification(df)
- # ## 复杂折线图实现
- if len(num_colmns) > 0:
- num_colmns.append(y)
- df_melted = pd.melt(
- df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
- )
- sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
- else:
- sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
+ # ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
- ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
- ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
+ chart_name = "line_" + str(uuid.uuid1()) + ".png"
+ chart_path = static_message_img_path + "/" + chart_name
+ plt.savefig(chart_path, dpi=100, transparent=True)
- chart_name = "line_" + str(uuid.uuid1()) + ".png"
- chart_path = static_message_img_path + "/" + chart_name
- plt.savefig(chart_path, bbox_inches="tight", dpi=100)
-
- html_img = f""""""
- return html_img
+ html_img = f"""
"""
+ return html_img
+ except Exception as e:
+ logging.error("Draw Line Chart Faild!" + str(e), e)
+ raise ValueError("Draw Line Chart Faild!" + str(e))
@command(
@@ -230,12 +241,12 @@ def response_bar_chart( df: DataFrame) -> str:
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
# 设置 y 轴刻度格式为普通数字格式
- ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
- ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
+ ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
+ # ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
- plt.savefig(chart_path, bbox_inches="tight", dpi=100)
+ plt.savefig(chart_path, dpi=100,transparent=True)
html_img = f"""
"""
return html_img
@@ -289,7 +300,7 @@ def response_pie_chart(df: DataFrame) -> str:
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
- plt.savefig(chart_path, bbox_inches="tight", dpi=100)
+ plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)
html_img = f"""
"""
diff --git a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
index 96c6f906c..9afd14ca5 100644
--- a/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
+++ b/pilot/base_modules/agent/commands/disply_type/show_table_gen.py
@@ -14,6 +14,6 @@ def response_table(df: DataFrame) -> str:
logger.info(f"response_table")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
- html = f"""
Fast, reliable, scalable open-source relational database management system.
Powerful, scalable, secure relational database system by Microsoft.
In-memory analytical database with efficient query processing.
Lightweight embedded relational database with simplicity and portability.
Columnar database for high-performance analytics and real-time queries.
Robust, scalable, secure relational database widely used in enterprises.
Easy-to-use relational database for small-scale applications by Microsoft.
Flexible, scalable NoSQL document database for web and mobile apps.
Scalable, secure relational database system developed by IBM.
Distributed, scalable NoSQL database for large structured/semi-structured data.
Fast, versatile in-memory data structure store as cache, DB, or broker.
Scalable, fault-tolerant distributed NoSQL database for large data.
High-performance NoSQL document database with distributed architecture.
Powerful open-source relational database with extensibility and SQL standards.
Unified engine for large-scale data analytics.
Fast, reliable, scalable open-source relational database management system.
Powerful, scalable, secure relational database system by Microsoft.
In-memory analytical database with efficient query processing.
Lightweight embedded relational database with simplicity and portability.
Columnar database for high-performance analytics and real-time queries.
Robust, scalable, secure relational database widely used in enterprises.
Easy-to-use relational database for small-scale applications by Microsoft.
Flexible, scalable NoSQL document database for web and mobile apps.
Scalable, secure relational database system developed by IBM.
Distributed, scalable NoSQL database for large structured/semi-structured data.
Fast, versatile in-memory data structure store as cache, DB, or broker.
Scalable, fault-tolerant distributed NoSQL database for large data.
High-performance NoSQL document database with distributed architecture.
Powerful open-source relational database with extensibility and SQL standards.
Unified engine for large-scale data analytics.
No data |
No data |