mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-29 14:57:35 +00:00
feat(Agent): ChatAgent And AgentHub
1.LLM TongYiQianWen WenXinYiYan ZhiPu support
This commit is contained in:
parent
acf6854842
commit
def77e66fe
@ -87,6 +87,11 @@ class CommandRegistry:
|
||||
if hasattr(reloaded_module, "register"):
|
||||
reloaded_module.register(self)
|
||||
|
||||
def is_valid_command(self, name: str) -> bool:
    """Report whether a command called ``name`` is registered.

    Args:
        name: Command name to look up in ``self.commands``.

    Returns:
        True when ``name`` is present in ``self.commands``, False otherwise.
    """
    # A membership test already yields the boolean; no if/else needed.
    return name in self.commands
|
||||
def get_command(self, name: str) -> Callable[..., Any]:
|
||||
return self.commands[name]
|
||||
|
||||
@ -229,6 +234,18 @@ class ApiCall:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
    """Remove markdown code fences wrapped directly around ``api_context``.

    For each known opening fence the following wrappers are unwrapped, in
    this order, leaving only ``api_context`` in place:

    1. ``<fence><api_context><end>``
    2. ``<fence>\\n<api_context>\\n<end>``
    3. ``<fence><api_context>`` (no closing fence consumed)

    Args:
        all_context: Full text in which ``api_context`` occurs.
        api_context: The exact api-call snippet whose fences are stripped.
        include_end: When True the closing fence is "```"; when False the
            closing fence is the empty string, so only the opening fence
            (and, for form 2, the surrounding newlines) is removed.

    Returns:
        ``all_context`` with the matching fences removed.
    """
    error_md_tags = ["```", "```python", "```xml", "```json", "```markdown"]
    md_tag_end = "```" if include_end else ""
    for tag in error_md_tags:
        all_context = all_context.replace(tag + api_context + md_tag_end, api_context)
        all_context = all_context.replace(tag + "\n" + api_context + "\n" + md_tag_end, api_context)
        all_context = all_context.replace(tag + api_context, api_context)
    return all_context
|
||||
|
||||
def api_view_context(self, all_context: str, display_mode: bool = False):
|
||||
error_mk_tags = ["```", "```python", "```xml"]
|
||||
call_context_map = extract_content_open_ending(all_context, self.agent_prefix, self.agent_end, True)
|
||||
@ -237,22 +254,20 @@ class ApiCall:
|
||||
if api_status is not None:
|
||||
if display_mode:
|
||||
if api_status.api_result:
|
||||
for tag in error_mk_tags:
|
||||
all_context = all_context.replace(tag + api_context + "```", api_context)
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, api_status.api_result)
|
||||
|
||||
else:
|
||||
if api_status.status == Status.FAILED.value:
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """)
|
||||
else:
|
||||
cost = (api_status.end_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
for tag in error_mk_tags:
|
||||
all_context = all_context.replace(tag + api_context + "```", api_context)
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context)
|
||||
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
|
||||
else:
|
||||
for tag in error_mk_tags:
|
||||
all_context = all_context.replace(tag + api_context + "```", api_context)
|
||||
all_context = self.__deal_error_md_tags(all_context, api_context, False)
|
||||
all_context = all_context.replace(api_context, self.to_view_text(api_status))
|
||||
|
||||
else:
|
||||
@ -260,13 +275,13 @@ class ApiCall:
|
||||
now_time = datetime.now().timestamp() * 1000
|
||||
cost = (now_time - self.start_time) / 1000
|
||||
cost_str = "{:.2f}".format(cost)
|
||||
for tag in error_mk_tags:
|
||||
all_context = all_context.replace(tag + api_context , api_context)
|
||||
all_context = all_context.replace(api_context, f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n')
|
||||
|
||||
return all_context
|
||||
|
||||
def update_from_context(self, all_context):
|
||||
logging.info(f"from_context:{all_context}")
|
||||
|
||||
api_context_map = extract_content(all_context, self.agent_prefix, self.agent_end, True)
|
||||
for api_index, api_context in api_context_map.items():
|
||||
api_context = api_context.replace("\\n", "").replace("\n", "")
|
||||
@ -308,7 +323,6 @@ class ApiCall:
|
||||
return result.decode("utf-8")
|
||||
|
||||
def run(self, llm_text):
|
||||
print(f"stream_plugin_call:{llm_text}")
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.__check_last_plugin_call_ready(llm_text):
|
||||
@ -327,8 +341,6 @@ class ApiCall:
|
||||
return self.api_view_context(llm_text)
|
||||
|
||||
def run_display_sql(self, llm_text, sql_run_func):
|
||||
print(f"get_display_sql:{llm_text}")
|
||||
|
||||
if self.__is_need_wait_plugin_call(llm_text):
|
||||
# wait api call generate complete
|
||||
if self.__check_last_plugin_call_ready(llm_text):
|
||||
@ -343,7 +355,7 @@ class ApiCall:
|
||||
param = {
|
||||
"df": sql_run_func(sql),
|
||||
}
|
||||
if self.display_registry.get_command(value.name):
|
||||
if self.display_registry.is_valid_command(value.name):
|
||||
value.api_result = self.display_registry.call(value.name, **param)
|
||||
else:
|
||||
value.api_result = self.display_registry.call("response_table", **param)
|
||||
|
@ -12,7 +12,7 @@ matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.ticker as mtick
|
||||
from matplotlib.font_manager import FontManager
|
||||
|
||||
from pilot.common.string_utils import is_scientific_notation
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -88,6 +88,14 @@ def zh_font_set():
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
def format_axis(value, pos):
    """Tick-label callback for ``matplotlib.ticker.FuncFormatter``.

    Args:
        value: Tick value supplied by matplotlib.
        pos: Tick position supplied by matplotlib; unused here.

    Returns:
        The value rendered with two decimal places when
        ``is_scientific_notation`` recognises it as a number, otherwise the
        plain string form of the value.
    """
    # Numeric values are rendered in fixed-point form so ticks never show
    # up in scientific notation.
    if is_scientific_notation(value):
        return '{:.2f}'.format(value)
    # A formatter callback is expected to return a string label.
    return str(value)
|
||||
|
||||
|
||||
@command(
|
||||
"response_line_chart",
|
||||
@ -98,58 +106,61 @@ def response_line_chart( df: DataFrame) -> str:
|
||||
logger.info(f"response_line_chart")
|
||||
if df.size <= 0:
|
||||
raise ValueError("No Data!")
|
||||
try:
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
|
||||
# set font
|
||||
# zh_font_set()
|
||||
font_names = [
|
||||
"Heiti TC",
|
||||
"Songti SC",
|
||||
"STHeiti Light",
|
||||
"Microsoft YaHei",
|
||||
"SimSun",
|
||||
"SimHei",
|
||||
"KaiTi",
|
||||
]
|
||||
fm = FontManager()
|
||||
mat_fonts = set(f.name for f in fm.ttflist)
|
||||
can_use_fonts = []
|
||||
for font_name in font_names:
|
||||
if font_name in mat_fonts:
|
||||
can_use_fonts.append(font_name)
|
||||
if len(can_use_fonts) > 0:
|
||||
plt.rcParams["font.sans-serif"] = can_use_fonts
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
|
||||
rc = {"font.sans-serif": can_use_fonts}
|
||||
plt.rcParams["axes.unicode_minus"] = False # 解决无法显示符号的问题
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
|
||||
sns.set(font=can_use_fonts[0], font_scale=0.8) # 解决Seaborn中文显示问题
|
||||
sns.set_palette("Set3") # 设置颜色主题
|
||||
sns.set_style("dark")
|
||||
sns.color_palette("hls", 10)
|
||||
sns.hls_palette(8, l=0.5, s=0.7)
|
||||
sns.set(context="notebook", style="ticks", rc=rc)
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
# ## 复杂折线图实现
|
||||
if len(num_colmns) > 0:
|
||||
num_colmns.append(y)
|
||||
df_melted = pd.melt(
|
||||
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
|
||||
)
|
||||
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
|
||||
else:
|
||||
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 5), dpi=100)
|
||||
x, y, non_num_columns, num_colmns = data_pre_classification(df)
|
||||
# ## 复杂折线图实现
|
||||
if len(num_colmns) > 0:
|
||||
num_colmns.append(y)
|
||||
df_melted = pd.melt(
|
||||
df, id_vars=x, value_vars=num_colmns, var_name="line", value_name="Value"
|
||||
)
|
||||
sns.lineplot(data=df_melted, x=x, y="Value", hue="line", ax=ax, palette="Set2")
|
||||
else:
|
||||
sns.lineplot(data=df, x=x, y=y, ax=ax, palette="Set2")
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
|
||||
# ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
|
||||
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, dpi=100, transparent=True)
|
||||
|
||||
chart_name = "line_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
except Exception as e:
|
||||
logging.error("Draw Line Chart Faild!" + str(e), e)
|
||||
raise ValueError("Draw Line Chart Faild!" + str(e))
|
||||
|
||||
|
||||
@command(
|
||||
@ -230,12 +241,12 @@ def response_bar_chart( df: DataFrame) -> str:
|
||||
sns.barplot(data=df, x=x, y=y, hue=hue, palette="Set2", ax=ax)
|
||||
|
||||
# 设置 y 轴刻度格式为普通数字格式
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda y, _: "{:,.0f}".format(y)))
|
||||
ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||
ax.yaxis.set_major_formatter(mtick.FuncFormatter(format_axis))
|
||||
# ax.xaxis.set_major_formatter(mtick.FuncFormatter(lambda x, _: "{:,.0f}".format(x)))
|
||||
|
||||
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
plt.savefig(chart_path, dpi=100,transparent=True)
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
return html_img
|
||||
|
||||
@ -289,7 +300,7 @@ def response_pie_chart(df: DataFrame) -> str:
|
||||
|
||||
chart_name = "pie_" + str(uuid.uuid1()) + ".png"
|
||||
chart_path = static_message_img_path + "/" + chart_name
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
|
||||
plt.savefig(chart_path, bbox_inches="tight", dpi=100, transparent=True)
|
||||
|
||||
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
|
||||
|
||||
|
@ -14,6 +14,6 @@ def response_table(df: DataFrame) -> str:
|
||||
logger.info(f"response_table")
|
||||
html_table = df.to_html(index=False, escape=False, sparsify=False)
|
||||
table_str = "".join(html_table.split())
|
||||
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
|
||||
view_text = html.replace("\n", " ")
|
||||
return view_text
|
||||
table_str = table_str.replace("\n", " ")
|
||||
html = f""" \n<div class="w-full overflow-auto">{table_str}</div>\n """
|
||||
return html
|
||||
|
@ -19,6 +19,17 @@ def is_chinese_include_number(text):
|
||||
match = re.match(pattern, text)
|
||||
return match is not None
|
||||
|
||||
def is_scientific_notation(string):
    """Report whether ``str(string)`` looks like a decimal number.

    Despite the name, the exponent part of the pattern is optional, so
    plain decimals such as ``"3.14"`` or ``"-.5"`` match as well as
    scientific notation such as ``"1e5"`` or ``"1.5e-07"``.

    Args:
        string: Any value; it is converted with ``str()`` before matching.

    Returns:
        True when the text matches the number pattern, False otherwise.
    """
    # Optional sign, optional integer part, optional dot, mandatory digits,
    # then an optional exponent.
    pattern = r'^[-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?$'
    return re.match(pattern, str(string)) is not None
|
||||
|
||||
def extract_content(long_string, s1, s2, is_include: bool = False):
|
||||
# extract text
|
||||
match_map ={}
|
||||
|
@ -6,61 +6,70 @@ from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tongyi_generate_stream(
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
model: ProxyModel, tokenizer, params, device, context_len=2048
|
||||
):
|
||||
|
||||
import dashscope
|
||||
from dashscope import Generation
|
||||
model_params = model.get_params()
|
||||
print(f"Model: {model}, model_params: {model_params}")
|
||||
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
proxy_api_key = model_params.proxy_api_key
|
||||
dashscope.api_key = proxy_api_key
|
||||
|
||||
|
||||
proxyllm_backend = model_params.proxyllm_backend
|
||||
if not proxyllm_backend:
|
||||
proxyllm_backend = Generation.Models.qwen_turbo # By Default qwen_turbo
|
||||
|
||||
|
||||
history = []
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
|
||||
for message in messages:
|
||||
if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
|
||||
role_define = messages.pop(0)
|
||||
history.append({"role": "system", "content": role_define.content})
|
||||
else:
|
||||
message = messages.pop(0)
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "system", "content": message.content})
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
# elif message.role == ModelMessageRoleType.HUMAN:
|
||||
# history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
history.append({"role": "system", "content": message.content})
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
|
||||
temp_his = history[::-1]
|
||||
|
||||
# temp_his = history[::-1]
|
||||
temp_his = history
|
||||
last_user_input = None
|
||||
for m in temp_his:
|
||||
if m["role"] == "user":
|
||||
last_user_input = m
|
||||
break
|
||||
|
||||
|
||||
if last_user_input:
|
||||
history.remove(last_user_input)
|
||||
history.append(last_user_input)
|
||||
|
||||
print(history)
|
||||
|
||||
gen = Generation()
|
||||
res = gen.call(
|
||||
proxyllm_backend,
|
||||
messages=history,
|
||||
top_p=params.get("top_p", 0.8),
|
||||
stream=True,
|
||||
result_format='message'
|
||||
result_format='message'
|
||||
)
|
||||
|
||||
|
||||
for r in res:
|
||||
if r["output"]["choices"][0]["message"].get("content") is not None:
|
||||
content = r["output"]["choices"][0]["message"].get("content")
|
||||
yield content
|
||||
|
||||
if r:
|
||||
if r['status_code'] == 200:
|
||||
content = r["output"]["choices"][0]["message"].get("content")
|
||||
yield content
|
||||
else:
|
||||
content = r['code'] + ":" + r["message"]
|
||||
yield content
|
||||
|
@ -46,24 +46,45 @@ def wenxin_generate_stream(
|
||||
proxy_server_url = f'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/{model_version}?access_token={access_token}'
|
||||
|
||||
if not access_token:
|
||||
yield "Failed to get access token. please set the correct api_key and secret key."
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
yield "Failed to get access token. please set the correct api_key and secret key."
|
||||
|
||||
history = []
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
for message in messages:
|
||||
system = ""
|
||||
if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
|
||||
role_define = messages.pop(0)
|
||||
system = role_define.content
|
||||
else:
|
||||
message = messages.pop(0)
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "system", "content": message.content})
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
# elif message.role == ModelMessageRoleType.HUMAN:
|
||||
# history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
# temp_his = history[::-1]
|
||||
temp_his = history
|
||||
last_user_input = None
|
||||
for m in temp_his:
|
||||
if m["role"] == "user":
|
||||
last_user_input = m
|
||||
break
|
||||
|
||||
if last_user_input:
|
||||
history.remove(last_user_input)
|
||||
history.append(last_user_input)
|
||||
|
||||
payload = {
|
||||
"messages": history,
|
||||
"system": system,
|
||||
"temperature": params.get("temperature"),
|
||||
"stream": True
|
||||
}
|
||||
|
@ -20,22 +20,40 @@ def zhipu_generate_stream(
|
||||
|
||||
import zhipuai
|
||||
zhipuai.api_key = proxy_api_key
|
||||
|
||||
history = []
|
||||
|
||||
messages: List[ModelMessage] = params["messages"]
|
||||
# Add history conversation
|
||||
|
||||
for message in messages:
|
||||
system = ""
|
||||
if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
|
||||
role_define = messages.pop(0)
|
||||
system = role_define.content
|
||||
else:
|
||||
message = messages.pop(0)
|
||||
if message.role == ModelMessageRoleType.HUMAN:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "system", "content": message.content})
|
||||
for message in messages:
|
||||
if message.role == ModelMessageRoleType.SYSTEM:
|
||||
history.append({"role": "user", "content": message.content})
|
||||
# elif message.role == ModelMessageRoleType.HUMAN:
|
||||
# history.append({"role": "user", "content": message.content})
|
||||
elif message.role == ModelMessageRoleType.AI:
|
||||
history.append({"role": "assistant", "content": message.content})
|
||||
else:
|
||||
pass
|
||||
|
||||
# temp_his = history[::-1]
|
||||
temp_his = history
|
||||
last_user_input = None
|
||||
for m in temp_his:
|
||||
if m["role"] == "user":
|
||||
last_user_input = m
|
||||
break
|
||||
|
||||
if last_user_input:
|
||||
history.remove(last_user_input)
|
||||
history.append(last_user_input)
|
||||
|
||||
res = zhipuai.model_api.sse_invoke(
|
||||
model=proxyllm_backend,
|
||||
prompt=history,
|
||||
|
@ -26,22 +26,19 @@ User Questions:
|
||||
|
||||
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
|
||||
_DEFAULT_TEMPLATE_ZH = """
|
||||
请使用上述历史对话中的数据结构信息,在满足约束条件下结合数据分析回答用户的问题。
|
||||
|
||||
请使用上述历史对话中的数据结构信息,在满足下面约束条件下结合数据分析回答用户的问题。
|
||||
约束条件:
|
||||
1.请先输出你的分析思路内容,再输出具体的数据分析结果,其中数据分析结果部分确保用如下格式输出:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
1.请先输出你的分析思路内容,再输出具体的数据分析结果。如果有数据数据分析时,请确保在输出的结果中包含如下格式内容:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定,把[]部分内容替换为对应的值
|
||||
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定,则使用'Text'作为显示,可用显示类型如下: {disply_type}
|
||||
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定,则使用'Text'作为显示,可用数据展示方式如下: {disply_type}
|
||||
4.SQL中需要使用的表名是: {table_name},请不要使用没在数据结构中的列名。
|
||||
5.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
|
||||
6.请确保你的输出内容有良好排版,输出内容均为普通markdown文本,不要用```或者```python这种标签来包围<api-call>的输出内容
|
||||
|
||||
用户问题:{user_input}
|
||||
请确保你的输出格式如下:
|
||||
需要告诉用户的分析思路.
|
||||
<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
"""
|
||||
输出给用户的分析文本信息.<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
|
||||
|
||||
用户问题:{user_input}
|
||||
"""
|
||||
|
||||
|
||||
_DEFAULT_TEMPLATE = (
|
||||
|
@ -4,6 +4,9 @@ import duckdb
|
||||
import os
|
||||
import re
|
||||
import sqlparse
|
||||
|
||||
import pandas as pd
|
||||
import chardet
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pyparsing import CaselessKeyword, Word, alphas, alphanums, delimitedList, Forward, Group, Optional,\
|
||||
@ -18,6 +21,16 @@ def excel_colunm_format(old_name: str) -> str:
|
||||
new_column = new_column.replace(" ", "_")
|
||||
return new_column
|
||||
|
||||
def detect_encoding(file_path):
    """Guess the character encoding of a file using chardet.

    The whole file is read in binary mode and handed to ``chardet.detect``.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        A ``(encoding, confidence)`` tuple taken from the ``'encoding'`` and
        ``'confidence'`` entries of the chardet result.
    """
    with open(file_path, 'rb') as handle:
        detection = chardet.detect(handle.read())
    return detection['encoding'], detection['confidence']
|
||||
|
||||
|
||||
def add_quotes_ex(sql: str, column_names):
|
||||
sql = sql.replace("`", '"')
|
||||
@ -76,16 +89,6 @@ def parse_sql(sql):
|
||||
|
||||
|
||||
|
||||
def add_quotes_v2(sql: str, column_names):
|
||||
pass
|
||||
|
||||
def add_quotes_v3(sql):
|
||||
pattern = r'[一-鿿]+'
|
||||
matches = re.findall(pattern, sql)
|
||||
for match in matches:
|
||||
sql = sql.replace(match, f'[{match}]')
|
||||
return sql
|
||||
|
||||
def add_quotes(sql, column_names=[]):
|
||||
sql = sql.replace("`", "")
|
||||
sql = sql.replace("'", "")
|
||||
@ -156,25 +159,26 @@ def process_identifier(identifier, column_names=[]):
|
||||
# if identifier.has_alias():
|
||||
# alias = identifier.get_alias()
|
||||
# identifier.tokens[-1].value = '[' + alias + ']'
|
||||
if identifier.tokens and identifier.value in column_names:
|
||||
if hasattr(identifier, 'tokens') and identifier.value in column_names:
|
||||
if is_chinese(identifier.value):
|
||||
new_value = get_new_value(identifier.value)
|
||||
identifier.value = new_value
|
||||
identifier.normalized = new_value
|
||||
identifier.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
else:
|
||||
for token in identifier.tokens:
|
||||
if isinstance(token, sqlparse.sql.Function):
|
||||
process_function(token)
|
||||
elif token.ttype in sqlparse.tokens.Name and is_chinese(token.value):
|
||||
new_value = get_new_value(token.value)
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
elif token.value in column_names and is_chinese(token.value):
|
||||
new_value = get_new_value(token.value)
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
if hasattr(identifier, 'tokens'):
|
||||
for token in identifier.tokens:
|
||||
if isinstance(token, sqlparse.sql.Function):
|
||||
process_function(token)
|
||||
elif token.ttype in sqlparse.tokens.Name :
|
||||
new_value = get_new_value(token.value)
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
elif token.value in column_names:
|
||||
new_value = get_new_value(token.value)
|
||||
token.value = new_value
|
||||
token.normalized = new_value
|
||||
token.tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
def get_new_value(value):
|
||||
return f""" "{value.replace("`", "").replace("'", "").replace('"', "")}" """
|
||||
|
||||
@ -186,11 +190,11 @@ def process_function(function):
|
||||
# 如果参数部分是一个标识符(字段名)
|
||||
if isinstance(param, sqlparse.sql.Identifier):
|
||||
# 判断是否需要替换字段值
|
||||
if is_chinese(param.value):
|
||||
# if is_chinese(param.value):
|
||||
# 替换字段值
|
||||
new_value = get_new_value(param.value)
|
||||
# new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
|
||||
function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
new_value = get_new_value(param.value)
|
||||
# new_parameter = sqlparse.sql.Identifier(f'[{param.value}]')
|
||||
function_params[i].tokens = [sqlparse.sql.Token(sqlparse.tokens.Name, new_value)]
|
||||
print(str(function))
|
||||
|
||||
def is_chinese(text):
|
||||
@ -221,7 +225,8 @@ class ExcelReader:
|
||||
def __init__(self, file_path):
|
||||
file_name = os.path.basename(file_path)
|
||||
file_name_without_extension = os.path.splitext(file_name)[0]
|
||||
|
||||
encoding, confidence = detect_encoding(file_path)
|
||||
logging.error(f"Detected Encoding: {encoding} (Confidence: {confidence})")
|
||||
self.excel_file_name = file_name
|
||||
self.extension = os.path.splitext(file_name)[1]
|
||||
# read excel file
|
||||
@ -232,9 +237,10 @@ class ExcelReader:
|
||||
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
|
||||
)
|
||||
elif file_path.endswith(".csv"):
|
||||
df_tmp = pd.read_csv(file_path)
|
||||
df_tmp = pd.read_csv(file_path, encoding=encoding)
|
||||
self.df = pd.read_csv(
|
||||
file_path,
|
||||
encoding = encoding,
|
||||
converters={i: csv_colunm_foramt for i in range(df_tmp.shape[1])},
|
||||
)
|
||||
else:
|
||||
@ -272,7 +278,7 @@ class ExcelReader:
|
||||
return colunms, results.fetchall()
|
||||
except Exception as e:
|
||||
logging.error("excel sql run error!", e)
|
||||
raise ValueError(f"Data Query Exception!{sql}")
|
||||
raise ValueError(f"Data Query Exception!\\nSQL[{sql}].\\nError:{str(e)}")
|
||||
|
||||
|
||||
def get_df_by_sql_ex(self, sql):
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user