feat(Agent): ChatAgent And AgentHub

1. ChatExcel supports streaming responses
2. ChatExcel supports multiple charts
This commit is contained in:
yhjun1026 2023-10-09 16:26:05 +08:00
parent ed702db9bc
commit f8de26b5ff
10 changed files with 270 additions and 192 deletions

View File

@ -1,16 +1,18 @@
import functools
import importlib
import inspect
import time
import json
import logging
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, List
from pydantic import BaseModel
from pilot.base_modules.agent.common.schema import Status, ApiTagType
from pilot.base_modules.agent.commands.command import execute_command
from pilot.base_modules.agent.commands.generator import PluginPromptGenerator
from pilot.common.string_utils import extract_content_include, extract_content_open_ending, extract_content, extract_content_include_open_ending
from pilot.common.string_utils import extract_content_open_ending, extract_content
# Unique identifier for auto-gpt commands
AUTO_GPT_COMMAND_IDENTIFIER = "auto_gpt_command"
@ -165,107 +167,179 @@ def command(
return decorator
class PluginStatus(BaseModel):
    """Execution record for one ``<api-call>`` block found in the model output.

    Instances are created by ``ApiCall.update_from_context`` and mutated while
    the call is executed (status, result, error message, end time).
    """

    # Command name taken from the <name> element of the api call.
    name: str
    # Start offsets in the model output at which this call text was found.
    location: List[int]
    # Argument name -> text, collected from the <args> element.
    args: dict
    status: Status = Status.TODO.value
    logo_url: Optional[str] = None
    api_result: Optional[str] = None
    err_msg: Optional[str] = None
    # Epoch milliseconds; filled in per instance by __init__ when not supplied.
    start_time: Optional[float] = None
    # Epoch milliseconds; stays None until the call has been executed.
    end_time: Optional[float] = None

    def __init__(self, **data: Any) -> None:
        # A class-level ``datetime.now()`` default is evaluated only once, at
        # import time, so every record would share the same start time.
        # Stamp each instance when it is created instead.
        data.setdefault("start_time", datetime.now().timestamp() * 1000)
        super().__init__(**data)
class ApiCall:
agent_prefix = "<api-call>"
agent_end = "</api-call>"
name_prefix = "<name>"
name_end = "</name>"
def __init__(self, plugin_generator):
self.name: str = ""
self.status: Status = Status.TODO.value
self.logo_url: str = None
self.args = {}
self.api_result: str = None
self.err_msg: str = None
def __init__(self, plugin_generator: Any = None, display_registry: Any = None):
# self.name: str = ""
# self.status: Status = Status.TODO.value
# self.logo_url: str = None
# self.args = {}
# self.api_result: str = None
# self.err_msg: str = None
self.plugin_status_map = {}
self.plugin_generator = plugin_generator
self.display_registry = display_registry
self.start_time = datetime.now().timestamp() * 1000
def __repr__(self):
return f"ApiCall(name={self.name}, status={self.status}, args={self.args})"
def __is_need_wait_plugin_call(self, api_call_context):
start_agent_count = api_call_context.count(self.agent_prefix)
end_agent_count = api_call_context.count(self.agent_end)
if api_call_context.find(self.agent_prefix) >= 0:
if start_agent_count > 0:
return True
check_len = len(self.agent_prefix)
last_text = api_call_context[-check_len:]
for i in range(check_len):
text_tmp = last_text[-i:]
prefix_tmp = self.agent_prefix[:i]
if text_tmp == prefix_tmp:
return True
else:
i += 1
else:
# 末尾新出字符检测
check_len = len(self.agent_prefix)
last_text = api_call_context[-check_len:]
for i in range(check_len):
text_tmp = last_text[-i:]
prefix_tmp = self.agent_prefix[:i]
if text_tmp == prefix_tmp:
return True
else:
i += 1
return False
def __get_api_call_context(self, all_context):
# Delegates to extract_content_include with the opening and closing api-call markers.
# NOTE(review): this change set appears to drop extract_content_include from both the
# string_utils import and string_utils itself - confirm this helper is still reachable
# and, if it is, switch it to extract_content(..., True) or remove it.
return extract_content_include(all_context, self.agent_prefix, self.agent_end)
def __check_last_plugin_call_ready(self, all_context):
start_agent_count = all_context.count(self.agent_prefix)
end_agent_count = all_context.count(self.agent_end)
def __check_plugin_call_ready(self, all_context):
if all_context.find(self.agent_end) > 0:
if start_agent_count > 0 and start_agent_count == end_agent_count:
return True
return False
def api_view_context(self, all_context: str, display_mode: bool = False):
    """Replace every api-call block in the text with its current view.

    Args:
        all_context: The model output received so far.
        display_mode: When True, a finished call is replaced by its raw
            ``api_result``; when False it is replaced by the ``dbgpt-view``
            element produced by ``to_view_text``.

    Returns:
        ``all_context`` with each (possibly still unterminated) api-call
        block substituted. Blocks without a status record get a plain
        "Waiting..." placeholder with the elapsed seconds.
    """
    call_context_map = extract_content_open_ending(
        all_context, self.agent_prefix, self.agent_end, True
    )
    now_time = datetime.now().timestamp() * 1000
    for api_context in call_context_map.values():
        api_status = self.plugin_status_map.get(api_context)
        if api_status is None:
            # update_from_context stores records under the newline-stripped
            # text, so retry the lookup with the same normalisation.
            normalized = api_context.replace("\\n", "").replace("\n", "")
            api_status = self.plugin_status_map.get(normalized)

        if api_status is None:
            # not ready api call view change
            cost_str = "{:.2f}".format((now_time - self.start_time) / 1000)
            view = f'\nWaiting... * {cost_str}S * \n'
        elif not display_mode:
            view = self.to_view_text(api_status)
        elif api_status.api_result:
            view = api_status.api_result
        elif api_status.status == Status.FAILED.value:
            view = f"""\n<span style=\"color:red\">ERROR!</span>{api_status.err_msg}\n """
        else:
            # end_time is None until the call has been executed; fall back to
            # the current time instead of failing on ``None - float``.
            end_time = api_status.end_time
            if end_time is None:
                end_time = now_time
            cost_str = "{:.2f}".format((end_time - self.start_time) / 1000)
            view = f'\n<span style=\"color:green\">Waiting...{cost_str}S</span>\n'
        all_context = all_context.replace(api_context, view)
    return all_context
def update_from_context(self, all_context):
    """Register every complete api-call block found in the text.

    Each block is stripped of newlines (real and escaped), parsed as XML and
    stored in ``plugin_status_map`` under its stripped text. A block that is
    already registered only has its start offset recorded.

    Args:
        all_context: The model output received so far.

    Raises:
        xml.etree.ElementTree.ParseError: If a block is not well-formed XML.
        AttributeError: If a block has no ``<name>`` element.
    """
    logging.info(f"from_context:{all_context}")
    api_context_map = extract_content(
        all_context, self.agent_prefix, self.agent_end, True
    )
    for api_index, api_context in api_context_map.items():
        api_context = api_context.replace("\\n", "").replace("\n", "")

        api_status = self.plugin_status_map.get(api_context)
        if api_status is not None:
            # This method runs on every streamed chunk; record each offset
            # once so the list does not grow with duplicates.
            if api_index not in api_status.location:
                api_status.location.append(api_index)
            continue

        api_call_element = ET.fromstring(api_context)
        api_name = api_call_element.find('name').text
        api_args = {}
        args_elements = api_call_element.find('args')
        if args_elements is not None:
            for child_element in args_elements.iter():
                # Element.iter() yields the <args> container first; it is not
                # an argument, so skip it.
                if child_element is args_elements:
                    continue
                api_args[child_element.tag] = child_element.text

        self.plugin_status_map[api_context] = PluginStatus(
            name=api_name, location=[api_index], args=api_args
        )
def __to_view_param_str(self):
def __to_view_param_str(self, api_status):
param = {}
if self.name:
param['name'] = self.name
param['status'] = self.status
if self.logo_url:
param['logo'] = self.logo_url
if api_status.name:
param['name'] = api_status.name
param['status'] = api_status.status
if api_status.logo_url:
param['logo'] = api_status.logo_url
if self.err_msg:
param['err_msg'] = self.err_msg
if self.api_result:
param['result'] = self.api_result
if api_status.err_msg:
param['err_msg'] = api_status.err_msg
if api_status.api_result:
param['result'] = api_status.api_result
return json.dumps(param)
def to_view_text(self, api_status: PluginStatus):
    """Render a call record as a serialised ``dbgpt-view`` XML element.

    The element text is the JSON produced by ``__to_view_param_str``; the
    element is serialised with UTF-8 encoding and returned as ``str``.
    """
    view_element = ET.Element('dbgpt-view')
    view_element.text = self.__to_view_param_str(api_status)
    return ET.tostring(view_element, encoding="utf-8").decode("utf-8")
def run(self, llm_text):
    """Execute pending plugin calls found in the text and return the view.

    Once every api-call block in ``llm_text`` is closed, the blocks are
    registered and each record still in TODO state is executed through
    ``execute_command``. A failure is stored on the record (status FAILED
    plus ``err_msg``) rather than raised, so one failing call does not stop
    the others.

    Args:
        llm_text: The model output received so far.

    Returns:
        ``llm_text`` with its api-call blocks replaced by view elements.
    """
    print(f"stream_plugin_call:{llm_text}")
    if self.__is_need_wait_plugin_call(llm_text):
        # wait api call generate complete
        if self.__check_last_plugin_call_ready(llm_text):
            self.update_from_context(llm_text)
            # Iterate the records themselves: looping over the dict directly
            # yields only the keys, which cannot be unpacked into key/value.
            for value in self.plugin_status_map.values():
                if value.status == Status.TODO.value:
                    value.status = Status.RUNNING.value
                    logging.info(f"插件执行:{value.name},{value.args}")
                    try:
                        value.api_result = execute_command(
                            value.name, value.args, self.plugin_generator
                        )
                        value.status = Status.COMPLETED.value
                    except Exception as e:
                        value.status = Status.FAILED.value
                        value.err_msg = str(e)
                    value.end_time = datetime.now().timestamp() * 1000
    return self.api_view_context(llm_text)
def run_display_sql(self, llm_text, sql_run_func):
# Streaming entry point for SQL-backed display calls. When the api-call blocks in
# llm_text are complete, each pending record has its 'sql' argument run through
# sql_run_func and the result rendered via self.display_registry. Always finishes
# by returning api_view_context(llm_text, True), i.e. the display-mode view.
print(f"get_display_sql:{llm_text}")
if self.__is_need_wait_plugin_call(llm_text):
# wait api call generate complete
if self.__check_last_plugin_call_ready(llm_text):
self.update_from_context(llm_text)
for key, value in self.plugin_status_map.items():
# Only records still in TODO are executed, so earlier calls are not re-run.
if value.status == Status.TODO.value:
value.status = Status.RUNNING.value
logging.info(f"sql展示执行:{value.name},{value.args}")
try:
# A missing 'sql' key raises KeyError here and is recorded as a failure below.
sql = value.args['sql']
if sql:
# The query result is passed to the display command as its "df" keyword argument.
param = {
"df": sql_run_func(sql),
}
# Use the display command named by the call when the registry knows it,
# otherwise fall back to "response_table".
if self.display_registry.get_command(value.name):
value.api_result = self.display_registry.call(value.name, **param)
else:
value.api_result = self.display_registry.call("response_table", **param)
# NOTE(review): indentation was lost in this view - confirm whether COMPLETED is
# set only when sql is non-empty or for every call that did not raise.
value.status = Status.COMPLETED.value
except Exception as e:
# Any error is stored on the record instead of being raised.
value.status = Status.FAILED.value
value.err_msg = str(e)
# Epoch milliseconds at which this record was updated.
value.end_time = datetime.now().timestamp() * 1000
return self.api_view_context(llm_text, True)

View File

@ -92,10 +92,10 @@ def zh_font_set():
@command(
"response_line_chart",
"Line chart display, used to display comparative trend analysis data",
'"speak": "<speak>", "df":"<data frame>"',
'"df":"<data frame>"',
)
def response_line_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_line_chart:{speak},")
def response_line_chart( df: DataFrame) -> str:
logger.info(f"response_line_chart")
if df.size <= 0:
raise ValueError("No Data")
@ -148,17 +148,17 @@ def response_line_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command(
"response_bar_chart",
"Histogram, suitable for comparative analysis of multiple target values",
'"speak": "<speak>", "df":"<data frame>"',
'"df":"<data frame>"',
)
def response_bar_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_bar_chart:{speak},")
def response_bar_chart( df: DataFrame) -> str:
logger.info(f"response_bar_chart")
if df.size <= 0:
raise ValueError("No Data")
@ -236,17 +236,17 @@ def response_bar_chart(speak: str, df: DataFrame) -> str:
chart_name = "bar_" + str(uuid.uuid1()) + ".png"
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img
@command(
"response_pie_chart",
"Pie chart, suitable for scenarios such as proportion and distribution statistics",
'"speak": "<speak>", "df":"<data frame>"',
'"df":"<data frame>"',
)
def response_pie_chart(speak: str, df: DataFrame) -> str:
logger.info(f"response_pie_chart:{speak},")
def response_pie_chart(df: DataFrame) -> str:
logger.info(f"response_pie_chart")
columns = df.columns.tolist()
if df.size <= 0:
raise ValueError("No Data")
@ -291,6 +291,6 @@ def response_pie_chart(speak: str, df: DataFrame) -> str:
chart_path = static_message_img_path + "/" + chart_name
plt.savefig(chart_path, bbox_inches="tight", dpi=100)
html_img = f"""<h5>{speak.replace("`", '"')}</h5><img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
html_img = f"""<img style='max-width: 100%; max-height: 70%;' src="/images/{chart_name}" />"""
return html_img

View File

@ -13,12 +13,12 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
    "response_table",
    "Table display, suitable for display with many display columns or non-numeric columns",
    '"df":"<data frame>"',
)
def response_table(df: DataFrame) -> str:
    """Render a data frame as an HTML table wrapped in a scrollable div.

    Args:
        df: The data to display.

    Returns:
        The wrapped table HTML with any newlines replaced by spaces.
    """
    logger.info(f"response_table")
    # NOTE(review): split()/join removes every whitespace character from the
    # generated HTML, including any inside cell values - confirm intended.
    compact_table = "".join(
        df.to_html(index=False, escape=False, sparsify=False).split()
    )
    wrapped = f"""<div class="w-full overflow-auto">{compact_table}</div>"""
    return wrapped.replace("\n", " ")

View File

@ -12,10 +12,10 @@ logger = build_logger("show_table_gen", LOGDIR + "show_table_gen.log")
@command(
"response_data_text",
"Text display, the default display method, suitable for single-line or simple content display",
'"speak": "<speak>", "df":"<data frame>"',
'"df":"<data frame>"',
)
def response_data_text(speak: str, df: DataFrame) -> str:
logger.info(f"response_data_text:{speak}")
def response_data_text(df: DataFrame) -> str:
logger.info(f"response_data_text")
data = df.values
row_size = data.shape[0]
@ -25,7 +25,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
html_table = df.to_html(index=False, escape=False, sparsify=False)
table_str = "".join(html_table.split())
html = f"""<div class="w-full overflow-auto">{table_str}</div>"""
text_info = f"##### {str(speak)}" + "\n" + html.replace("\n", " ")
text_info = html.replace("\n", " ")
elif row_size == 1:
row = data[0]
for value in row:
@ -33,7 +33,7 @@ def response_data_text(speak: str, df: DataFrame) -> str:
value_str = value_str + f", ** {value} **"
else:
value_str = f" ** {value} **"
text_info = f"{speak}: {value_str}"
text_info = f" {value_str}"
else:
text_info = f"##### {speak}: _没有找到可用的数据_"
text_info = f"##### _没有找到可用的数据_"
return text_info

View File

@ -1,49 +1,47 @@
def extract_content(long_string, s1, s2, is_include: bool = False):
    """Collect every span of ``long_string`` delimited by ``s1`` ... ``s2``.

    Args:
        long_string: Text to scan.
        s1: Opening marker.
        s2: Closing marker.
        is_include: When True the markers are part of each extracted span;
            when False only the text between them is returned.

    Returns:
        A dict mapping the index at which an opening marker starts to the
        text extracted for it. Empty extractions are omitted, and an opening
        marker with no closing marker after it yields nothing.
    """
    match_map = {}
    start_index = long_string.find(s1)
    while start_index != -1:
        content_start = start_index + len(s1)
        # Search right after the opening marker in both modes so a closing
        # marker that immediately follows it is not skipped.
        end_index = long_string.find(s2, content_start)
        if end_index == -1:
            # No closing marker from here on, so later openings cannot match
            # either. Slicing with -1 would return a bogus fragment.
            break
        if is_include:
            extracted_content = long_string[start_index:end_index + len(s2)]
        else:
            extracted_content = long_string[content_start:end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map
def extract_content_open_ending(long_string, s1, s2, is_include: bool = False):
    """Collect ``s1`` ... ``s2`` spans, accepting an unterminated last span.

    Works like ``extract_content`` except that an opening marker without a
    closing marker after it extracts everything up to the end of the text.

    Args:
        long_string: Text to scan.
        s1: Opening marker.
        s2: Closing marker.
        is_include: When True the markers are part of each extracted span;
            when False only the text between them is returned.

    Returns:
        A dict mapping the index at which an opening marker starts to the
        text extracted for it. Empty extractions are omitted.
    """
    match_map = {}
    start_index = long_string.find(s1)
    while start_index != -1:
        content_start = start_index + len(s1)
        # One search decides both whether the span is closed and where it
        # ends; two differently-offset searches could disagree and yield -1.
        end_index = long_string.find(s2, content_start)
        if end_index == -1:
            # Open ending: take the rest of the text.
            if is_include:
                extracted_content = long_string[start_index:]
            else:
                extracted_content = long_string[content_start:]
        elif is_include:
            extracted_content = long_string[start_index:end_index + len(s2)]
        else:
            extracted_content = long_string[content_start:end_index]
        if extracted_content:
            match_map[start_index] = extracted_content
        start_index = long_string.find(s1, start_index + 1)
    return match_map
if __name__ == "__main__":
    # Manual smoke check: three closed spans followed by one unterminated span.
    sample_text = "abcd123efghijkjhhh456xxx123aa456yyy123bb456xx123"
    opening, closing = "123", "456"
    print(extract_content_open_ending(sample_text, opening, closing, True))

View File

@ -34,7 +34,7 @@ class ChatAgent(BaseChat):
agent_module = CFG.SYSTEM_APP.get_componet(ComponetType.AGENT_HUB, ModuleAgent)
self.plugins_prompt_generator = agent_module.load_select_plugin(self.plugins_prompt_generator, self.select_plugins)
self.api_call = ApiCall(self.plugins_prompt_generator)
self.api_call = ApiCall(plugin_generator=self.plugins_prompt_generator)
def generate_input_values(self):
input_values = {

View File

@ -15,27 +15,6 @@ class PluginAction(NamedTuple):
class PluginChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T:
clean_json_str = super().parse_prompt_response(model_out_text)
print(clean_json_str)
if not clean_json_str:
raise ValueError("model server response not have json!")
try:
response = json.loads(clean_json_str)
except Exception as e:
raise ValueError("model server out not fllow the prompt!")
speak = ""
thoughts = ""
for key in sorted(response):
if key.strip() == "command":
command = response[key]
if key.strip() == "thoughts":
thoughts = response[key]
if key.strip() == "speak":
speak = response[key]
return PluginAction(command, speak, thoughts)
def parse_view_response(self, speak, data) -> str:
### tool out data to table view
print(f"parse_view_response:{speak},{str(data)}")

View File

@ -7,13 +7,13 @@ from pilot.scene.base_chat import BaseChat, logger
from pilot.scene.base import ChatScene
from pilot.common.sql_database import Database
from pilot.configs.config import Config
from pilot.base_modules.agent.commands.command_mange import ApiCall
from pilot.scene.chat_data.chat_excel.excel_analyze.prompt import prompt
from pilot.scene.chat_data.chat_excel.excel_reader import ExcelReader
from pilot.scene.chat_data.chat_excel.excel_learning.chat import ExcelLearning
from pilot.common.path_utils import has_path
from pilot.configs.model_config import LLM_MODEL_CONFIG, KNOWLEDGE_UPLOAD_ROOT_PATH
from pilot.base_modules.agent.common.schema import Status
CFG = Config()
@ -36,9 +36,18 @@ class ChatExcel(BaseChat):
KNOWLEDGE_UPLOAD_ROOT_PATH, chat_mode.value(), self.select_param
)
)
self.api_call = ApiCall(display_registry = CFG.command_disply)
super().__init__(chat_param=chat_param)
def _generate_numbered_list(self) -> str:
    """Return the enabled display commands as a 1-based numbered list.

    Each line has the form ``"<n>. <command>"``. The result is an empty
    string when no display registry is configured or nothing is enabled.
    """
    registry = CFG.command_disply
    if not registry:
        return ""
    enabled_commands = (
        str(item) for item in registry.commands.values() if item.enabled
    )
    return "\n".join(
        f"{position}. {item}"
        for position, item in enumerate(enabled_commands, start=1)
    )
def generate_input_values(self):
input_values = {
@ -64,15 +73,25 @@ class ChatExcel(BaseChat):
result = await learn_chat.nostream_call()
return result
def do_action(self, prompt_response):
print(f"do_action:{prompt_response}")
def stream_plugin_call(self, text):
    """Turn streamed model text into display output.

    Newlines are replaced by spaces, then the text is passed to
    ``ApiCall.run_display_sql`` together with the Excel reader's SQL runner.
    """
    single_line_text = text.replace("\n", " ")
    print(f"stream_plugin_call:{single_line_text}")
    sql_runner = self.excel_reader.get_df_by_sql_ex
    return self.api_call.run_display_sql(single_line_text, sql_runner)
# colunms, datas = self.excel_reader.run(prompt_response.sql)
param = {
"speak": prompt_response.thoughts,
"df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
}
if CFG.command_disply.get_command(prompt_response.display):
return CFG.command_disply.call(prompt_response.display, **param)
else:
return CFG.command_disply.call("response_table", **param)
# def do_action(self, prompt_response):
# print(f"do_action:{prompt_response}")
#
# # colunms, datas = self.excel_reader.run(prompt_response.sql)
#
#
# param = {
# "speak": prompt_response.thoughts,
# "df": self.excel_reader.get_df_by_sql_ex(prompt_response.sql),
# }
#
# if CFG.command_disply.get_command(prompt_response.display):
# return CFG.command_disply.call(prompt_response.display, **param)
# else:
# return CFG.command_disply.call("response_table", **param)

View File

@ -9,50 +9,52 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
Please use the data structure information of the above historical dialogue, make sure not to use column names that are not in the data structure.
According to the user goal: {user_input}give the correct duckdb SQL for data analysis.
Use the table name: {table_name}
Please use the data structure and column information in the above historical dialogue and combine it with data analysis to answer the user's questions while satisfying the constraints.
According to the analysis SQL obtained by the user's goal, select the best one from the following display forms, if it cannot be determined, use Text as the display,Just need to return the type name into the result.
Display type:
{disply_type}
Respond in the following json format:
{response}
Ensure the response is correct json and can be parsed by Python json.loads
Constraint:
1.Please output your thinking process and analysis ideas first, and then output the specific data analysis results. The data analysis results are output in the following format:<api-call><name>display type</name><args><sql>Correct duckdb data analysis sql</sql></args></api-call>
2.For the available display methods of data analysis results, please choose the most appropriate one from the following display methods. If you are not sure, use 'response_data_text' as the display. The available display types are as follows:{disply_type}
3.The table name that needs to be used in SQL is: {table_name}, please make sure not to use column names that are not in the data structure.
4.Give priority to answering using data analysis. If the user's question does not involve data analysis, you can answer according to your understanding.
User Questions:
{user_input}
"""
_PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
_DEFAULT_TEMPLATE_ZH = """
请使用上述历史对话中的数据结构和列信息根据用户目标{user_input}给出正确的duckdb SQL进行数据分析和问题回答
请确保不要使用不在数据结构中的列名
SQL中需要使用的表名是: {table_name}
请使用上述历史对话中的数据结构信息在满足约束条件下结合数据分析回答用户的问题
根据用户目标得到的分析SQL请从以下显示类型中选择最合适的一种用来展示结果数据如果无法确定则使用'Text'作为显示, 只需要将类型名称返回到结果中
显示类型如下:
{disply_type}
以以下 json 格式响应:
{response}
确保响应是正确的json,并且可以被Python的json.loads方法解析.
约束条件:
1.请先输出你的分析思路内容再输出具体的数据分析结果其中数据分析结果部分确保用如下格式输出:<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
2.请确保数据分析结果格式的内容在整个回答中只出现一次,确保上述结构稳定[]部分内容替换为对应的值
3.数据分析结果可用的展示方式请在下面的展示方式中选择最合适的一种,放入数据分析结果的name字段内如果无法确定则使用'Text'作为显示可用显示类型如下: {disply_type}
4.SQL中需要使用的表名是: {table_name},请确保不要使用不在数据结构中的列名
5.优先使用数据分析的方式回答如果用户问题不涉及数据分析内容你可以按你的理解进行回答
6.请确保你的输出内容有良好的markdown格式的排版
用户问题{user_input}
请确保你的输出格式如下:
需要告诉用户的分析思路.
<api-call><name>[数据展示方式]</name><args><sql>[正确的duckdb数据分析sql]</sql></args></api-call>
"""
RESPONSE_FORMAT_SIMPLE = {
"sql": "analysis SQL",
"thoughts": "Current thinking and value of data analysis",
"display": "display type name",
}
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
_PROMPT_SCENE_DEFINE = (
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value
PROMPT_NEED_NEED_STREAM_OUT = False
PROMPT_NEED_NEED_STREAM_OUT = True
# Temperature is a configuration hyperparameter that controls the randomness of language model output.
# A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
@ -62,8 +64,7 @@ PROMPT_TEMPERATURE = 0.8
prompt = PromptTemplate(
template_scene=ChatScene.ChatExcel.value(),
input_variables=["user_input", "table_name", "disply_type"],
response_format=json.dumps(RESPONSE_FORMAT_SIMPLE, ensure_ascii=False, indent=4),
template_define=PROMPT_SCENE_DEFINE,
template_define=_PROMPT_SCENE_DEFINE,
template=_DEFAULT_TEMPLATE,
stream_out=PROMPT_NEED_NEED_STREAM_OUT,
output_parser=ChatExcelOutputParser(

View File

@ -9,7 +9,7 @@ from pilot.common.schema import SeparatorStyle
CFG = Config()
PROMPT_SCENE_DEFINE = "You are a data analysis expert. "
_PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
_DEFAULT_TEMPLATE_EN = """
This is an example dataplease learn to understand the structure and content of this data:
@ -21,6 +21,8 @@ Please return your answer in JSON format, the return format is as follows:
{response}
"""
_PROMPT_SCENE_DEFINE_ZH = "你是一个数据分析专家. "
_DEFAULT_TEMPLATE_ZH = """
下面是一份示例数据请学习理解该数据的结构和内容:
{data_example}
@ -37,10 +39,15 @@ RESPONSE_FORMAT_SIMPLE = {
"AnalysisProgram": ["1.分析方案1图表展示方式1", "2.分析方案2图表展示方式2"],
}
_DEFAULT_TEMPLATE = (
_DEFAULT_TEMPLATE_EN if CFG.LANGUAGE == "en" else _DEFAULT_TEMPLATE_ZH
)
PROMPT_SCENE_DEFINE =(
_PROMPT_SCENE_DEFINE_EN if CFG.LANGUAGE == "en" else _PROMPT_SCENE_DEFINE_ZH
)
PROMPT_SEP = SeparatorStyle.SINGLE.value