feat(Agent): ChatAgent And AgentHub

1. LLM support: TongYiQianWen, WenXinYiYan, and ZhiPu
yhjun1026 2023-10-13 15:46:07 +08:00
parent acf6854842
commit def77e66fe
21 changed files with 234 additions and 146 deletions

View File

@ -87,6 +87,11 @@ class CommandRegistry:
if hasattr(reloaded_module, "register"):
reloaded_module.register(self)
def is_valid_command(self, name: str) -> bool:
if name not in self.commands:
return False
else:
return True
def get_command(self, name: str) -> Callable[..., Any]:
return self.commands[name]
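The new helper reduces to a dict membership test; a tiny runnable sketch of the same pattern (the command table here is made up):

commands = {"response_table": lambda **kw: kw}

def is_valid_command(name: str) -> bool:
    return name in commands  # equivalent to the committed if/else

assert is_valid_command("response_table")
assert not is_valid_command("missing")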
@ -229,6 +234,18 @@ class ApiCall:
return True
return False
def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
error_md_tags = ["```", "```python", "```xml", "```json", "```markdown"]
if include_end == False:
md_tag_end = ""
else:
md_tag_end = "```"
for tag in error_md_tags:
all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
all_context = all_context.replace(tag + "\n" +api_context + "\n" + md_tag_end, api_context)
all_context = all_context.replace(tag + api_context, api_context)
return all_context
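Concretely, one iteration of that replace loop on a made-up model response; the surrounding ```json fences are stripped while the <api-call> payload survives:

call = "<api-call><name>response_table</name></api-call>"
text = "analysis...\n```json" + call + "```"
cleaned = text.replace("```json" + call + "```", call)
assert cleaned == "analysis...\n" + call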
def api_view_context(self, all_context: str, display_mode: bool = False):
error_mk_tags = ["```", "```python", "```xml"]
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
@ -237,22 +254,20 @@ class ApiCall:
if api_status is not None:
if display_mode:
if api_status.api_result:
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, api_status.api_result)
else:
if api_status.status == Status.FAILED.value:
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """)
else:
cost = (api_status.end_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
all_context = self.__deal_error_md_tags(all_context, api_context)
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
else:
all_context = self.__deal_error_md_tags(all_context, api_context, False)
all_context = all_context.replace(api_context, self.to_view_text(api_status))
else:
@ -260,13 +275,13 @@ class ApiCall:
now_time = datetime.now().timestamp() * 1000
cost = (now_time - self.start_time) / 1000
cost_str = "{:.2f}".format(cost)
for tag in error_mk_tags:
all_context = all_context.replace(tag + api_context , api_context)
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
return all_context
def update_from_context(self, all_context):
logging.info(f"from_context:{all_context}")
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
for api_index, api_context in api_context_map.items():
api_context = api_context.replace("\\n", "").replace("\n", "")
@ -308,7 +323,6 @@ class ApiCall:
return result.decode("utf-8")
def run(self, llm_text):
print(f"stream_plugin_call:{llm_text}")
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
@ -327,8 +341,6 @@ class ApiCall:
return self.api_view_context(llm_text)
def run_display_sql(self, llm_text, sql_run_func):
print(f"get_display_sql:{llm_text}")
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
@ -343,7 +355,7 @@ class ApiCall:
param = {
"df": sql_run_func(sql),
}
if self.display_registry.is_valid_command(value.name):
value.api_result = self.display_registry.call(value.name, **param)
else:
value.api_result = self.display_registry.call("response_table", **param)

View File

@ -12,7 +12,7 @@ matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
from matplotlib.font_manager import FontManager
from pilot.common.string_utils import is_scientific_notation
import logging
logger = logging.getLogger(__name__)
@ -88,6 +88,14 @@ def zh_font_set():
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
def format_axis(value, pos):
    # Check whether the tick value is numeric
    if is_scientific_notation(value):
        # If so, format it out of scientific notation, with two decimals
        return '{:.2f}'.format(value)
    return value
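A self-contained check of the formatter's behavior; the underscore-prefixed helpers are stand-ins for the real functions, re-declared so the snippet runs on its own:

import re

def _is_number(s) -> bool:
    # same pattern as pilot.common.string_utils.is_scientific_notation
    return re.match(r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$', str(s)) is not None

def _format_axis(value, pos):
    return '{:.2f}'.format(value) if _is_number(value) else value

print(_format_axis(1.5e6, None))  # '1500000.00' rather than '1.5e+06'
print(_format_axis("华东", None))  # non-numeric labels pass through unchanged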
@command(
"response_line_chart",
@ -98,58 +106,61 @@ def response_line_chart( df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data")
try:
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
# set font
# zh_font_set()
font_names = [
"Heiti TC",
"Songti SC",
"STHeiti Light",
"Microsoft YaHei",
"SimSun",
"SimHei",
"KaiTi",
]
fm = FontManager()
mat_fonts = set(f.name for f in fm.ttflist)
can_use_fonts = []
for font_name in font_names:
if font_name in mat_fonts:
can_use_fonts.append(font_name)
if len(can_use_fonts) > 0:
plt.rcParams["font.sans-serif"] = can_use_fonts
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
rc = {"font.sans-serif": can_use_fonts}
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
sns.set_palette("Set3") # 设置颜色主题
sns.set_style("dark")
sns.color_palette("hls", 10)
sns.hls_palette(8, l=0.5, s=0.7)
sns.set(context="notebook", style="ticks", rc=rc)
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
x, y, non_num_columns, num_colmns = data_pre_classification(df)
# ## 复杂折线图实现
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
)
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
x, y, non_num_columns, num_colmns = data_pre_classification(df)
# ## 复杂折线图实现
if len(num_colmns) > 0:
num_colmns.append(y)
df_melted = pd.melt(
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
)
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
else:
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
# ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, dpi=100, transparent=True)
chart_name = "line_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
except Exception as e:
logging.error("Draw Line Chart Faild!" + str(e), e)
raise ValueError("Draw Line Chart Faild!" + str(e))
@command(
@ -230,12 +241,12 @@ def response_bar_chart( df: DataFrame) -> str:
    sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
    # Format the y-axis ticks as plain (non-scientific) numbers
    ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
    # ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))

    chart_name = "bar_" + str(uuid.uuid1()) + ".png"
    chart_path = static_message_img_path + "/" + chart_name
    plt.savefig(chart_path, dpi=100, transparent=True)

    html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
    return html_img
@ -289,7 +300,7 @@ def response_pie_chart(df: DataFrame) -> str:
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""

View File

@ -14,6 +14,6 @@ def response_table(df: DataFrame) -> str:
logger.info(f"response_table")
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
view_text = html.replace("\n", " ")
return view_text
table_str = table_str.replace("\n", " ")
html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
return html
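A quick, runnable check of the new rendering, assuming only pandas is installed (the sample frame is made up):

import pandas as pd

df = pd.DataFrame({"city": ["北京", "上海"], "sales": [1000, 2000]})
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())  # note: removes all whitespace, including inside cell text
table_str = table_str.replace("\n", " ")
html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
print(html)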

View File

@ -19,6 +19,17 @@ def is_chinese_include_number(text):
    match = re.match(pattern, text)
    return match is not None

def is_scientific_notation(string):
    # Regex for plain or scientific-notation numbers
    pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$'
    # Match the string (coerced to str) against the pattern
    match = re.match(pattern, str(string))
    # Report whether the match succeeded
    if match is not None:
        return True
    else:
        return False
def extract_content(long_string, s1, s2, is_include: bool = False):
    # extract text
    match_map = {}

View File

@ -6,61 +6,70 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
logger = logging.getLogger(__name__)
def tongyi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    import dashscope
    from dashscope import Generation

    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    dashscope.api_key = proxy_api_key

    proxyllm_backend = model_params.proxyllm_backend
    if not proxyllm_backend:
        proxyllm_backend = Generation.Models.qwen_turbo  # By default, qwen_turbo

    history = []
    messages: List[ModelMessage] = params["messages"]

    # Add history conversation
    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
        role_define = messages.pop(0)
        history.append({"role": "system", "content": role_define.content})
    else:
        message = messages.pop(0)
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
    for message in messages:
        if message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "user", "content": message.content})
        # elif message.role == ModelMessageRoleType.HUMAN:
        #     history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the first user turn to the end of the history
    # temp_his = history[::-1]
    temp_his = history
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)
    print(history)

    gen = Generation()
    res = gen.call(
        proxyllm_backend,
        messages=history,
        top_p=params.get("top_p", 0.8),
        stream=True,
        result_format='message'
    )

    for r in res:
        if r:
            if r['status_code'] == 200:
                content = r["output"]["choices"][0]["message"].get("content")
                yield content
            else:
                content = r['code'] + ":" + r["message"]
                yield content
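The reordering step in isolation: with the history[::-1] reversal dropped, the loop finds the first "user" entry and moves it to the end, as dashscope appears to expect the latest user turn last (sample messages are made up):

history = [
    {"role": "system", "content": "You are a data analyst."},
    {"role": "user", "content": "Plot monthly sales."},
    {"role": "assistant", "content": "Here is the chart."},
]

last_user_input = None
for m in history:
    if m["role"] == "user":
        last_user_input = m
        break
if last_user_input:
    history.remove(last_user_input)
    history.append(last_user_input)

print([m["role"] for m in history])  # ['system', 'assistant', 'user']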

View File

@ -46,24 +46,45 @@ def wenxin_generate_stream(
    proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
    if not access_token:
        yield "Failed to get access token. please set the correct api_key and secret key."

    history = []
    messages: List[ModelMessage] = params["messages"]

    # Add history conversation
    system = ""
    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
        role_define = messages.pop(0)
        system = role_define.content
    else:
        message = messages.pop(0)
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
    for message in messages:
        if message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "user", "content": message.content})
        # elif message.role == ModelMessageRoleType.HUMAN:
        #     history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the first user turn to the end of the history
    # temp_his = history[::-1]
    temp_his = history
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    payload = {
        "messages": history,
        "system": system,
        "temperature": params.get("temperature"),
        "stream": True
    }

View File

@ -20,22 +20,40 @@ def zhipu_generate_stream(
    import zhipuai

    zhipuai.api_key = proxy_api_key

    history = []
    messages: List[ModelMessage] = params["messages"]

    # Add history conversation
    system = ""
    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
        role_define = messages.pop(0)
        system = role_define.content
    else:
        message = messages.pop(0)
        if message.role == ModelMessageRoleType.HUMAN:
            history.append({"role": "user", "content": message.content})
    for message in messages:
        if message.role == ModelMessageRoleType.SYSTEM:
            history.append({"role": "user", "content": message.content})
        # elif message.role == ModelMessageRoleType.HUMAN:
        #     history.append({"role": "user", "content": message.content})
        elif message.role == ModelMessageRoleType.AI:
            history.append({"role": "assistant", "content": message.content})
        else:
            pass

    # Move the first user turn to the end of the history
    # temp_his = history[::-1]
    temp_his = history
    last_user_input = None
    for m in temp_his:
        if m["role"] == "user":
            last_user_input = m
            break
    if last_user_input:
        history.remove(last_user_input)
        history.append(last_user_input)

    res = zhipuai.model_api.sse_invoke(
        model=proxyllm_backend,
        prompt=history,

View File

@ -26,22 +26,19 @@ User Questions:
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
请使用上述历史对话中的数据结构信息在满足约束条件下结合数据分析回答用户的问题
请使用上述历史对话中的数据结构信息在满足下面约束条件下结合数据分析回答用户的问题
约束条件:
1.请先输出你的分析思路内容再输出具体的数据分析结果其中数据分析结果部分确保用如下格式输出:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
1.请先输出你的分析思路内容再输出具体的数据分析结果如果有数据数据分析时请确保在输出的结果中包含如下格式内容:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定[]部分内容替换为对应的值
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定则使用'Text'作为显示可用显示类型如下: {disply_type}
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定则使用'Text'作为显示可用数据展示方式如下: {disply_type}
4.SQL中需要使用的表名是: {table_name},请不要使用没在数据结构中的列名
5.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
6.请确保你的输出内容有良好排版输出内容均为普通markdown文本,不要用```或者```python这种标签来包围<api-call>的输出内容
用户问题{user_input}
请确保你的输出格式如下:
需要告诉用户的分析思路.
<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
"""
输出给用户的分析文本信息.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
用户问题{user_input}
"""
_DEFAULT_TEMPLATE = (

View File

@ -4,6 +4,9 @@ import duckdb
import os
import re
import sqlparse
import chardet
import pandas as pd
import numpy as np
from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\
@ -18,6 +21,16 @@ def excel_colunm_format(old_name: str) -> str:
new_column = new_column.replace(" ", "_")
return new_column
def detect_encoding(file_path):
    # Read the file's raw bytes
    with open(file_path, 'rb') as f:
        data = f.read()
    # Use chardet to detect the file encoding
    result = chardet.detect(data)
    encoding = result['encoding']
    confidence = result['confidence']
    return encoding, confidence
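Usage sketch for detect_encoding; the file name and contents are invented, and chardet commonly labels GBK-encoded text as GB2312:

import chardet
import pandas as pd

with open("sales_gbk.csv", "wb") as f:
    f.write("月份,销售额\n一月,1000\n".encode("gbk"))

with open("sales_gbk.csv", "rb") as f:
    result = chardet.detect(f.read())
print(result["encoding"], result["confidence"])  # e.g. GB2312 0.99

df = pd.read_csv("sales_gbk.csv", encoding=result["encoding"])
print(df.columns.tolist())  # ['月份', '销售额']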
def add_quotes_ex(sql: str, column_names):
sql = sql.replace("`", '"')
@ -76,16 +89,6 @@ def parse_sql(sql):
def add_quotes_v2(sql: str, column_names):
pass
def add_quotes_v3(sql):
pattern = r'[一-鿿]+'
matches = re.findall(pattern, sql)
for match in matches:
sql = sql.replace(match, f'[{match}]')
return sql
def add_quotes(sql, column_names=[]):
sql = sql.replace("`", "")
sql = sql.replace("'", "")
@ -156,25 +159,26 @@ def process_identifier(identifier, column_names=[]):
# if identifier.has_alias():
# alias = identifier.get_alias()
# identifier.tokens[-1].value = '[' + alias + ']'
    if hasattr(identifier, 'tokens') and identifier.value in column_names:
        new_value = get_new_value(identifier.value)
        identifier.value = new_value
        identifier.normalized = new_value
        identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
    else:
        if hasattr(identifier, 'tokens'):
            for token in identifier.tokens:
                if isinstance(token, sqlparse.sql.Function):
                    process_function(token)
                elif token.ttype in sqlparse.tokens.Name:
                    new_value = get_new_value(token.value)
                    token.value = new_value
                    token.normalized = new_value
                elif token.value in column_names:
                    new_value = get_new_value(token.value)
                    token.value = new_value
                    token.normalized = new_value
                    token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
def get_new_value(value):
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
@ -186,11 +190,11 @@ def process_function(function):
        # If the parameter is an identifier (a column name)
        if isinstance(param, sqlparse.sql.Identifier):
            # Rewrite the parameter value unconditionally (the is_chinese guard was removed)
            # if is_chinese(param.value):
            new_value = get_new_value(param.value)
            # new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
            function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
    print(str(function))
def is_chinese(text):
@ -221,7 +225,8 @@ class ExcelReader:
def __init__(self, file_path):
file_name = os.path.basename(file_path)
file_name_without_extension = os.path.splitext(file_name)[0]
encoding, confidence = detect_encoding(file_path)
logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
self.excel_file_name = file_name
self.extension = os.path.splitext(file_name)[1]
# read excel file
@ -232,9 +237,10 @@ class ExcelReader:
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
elif file_path.endswith(".csv"):
df_tmp = pd.read_csv(file_path)
df_tmp = pd.read_csv(file_path, encoding=encoding)
self.df = pd.read_csv(
file_path,
encoding = encoding,
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
)
else:
@ -272,7 +278,7 @@ class ExcelReader:
            return colunms, results.fetchall()
        except Exception as e:
            logging.error("excel sql run error!" + str(e))
            raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
def get_df_by_sql_ex(self, sql):

File diff suppressed because one or more lines are too long (9 files)

View File

@ -417,6 +417,9 @@ def default_requires():
"accelerate>=0.20.3",
"sentence-transformers",
"protobuf==3.20.3",
"zhipuai",
"dashscope",
"chardet"
]
setup_spec.extras["default"] += setup_spec.extras["framework"]
setup_spec.extras["default"] += setup_spec.extras["knowledge"]