diff --git a/pilot/base_modules/agent/commands/command_mange.py b/pilot/base_modules/agent/commands/command_mange.py
index 180ccfe5a..7e1f48194 100644
--- a/pilot/base_modules/agent/commands/command_mange.py
+++ b/pilot/base_modules/agent/commands/command_mange.py
@@ -246,7 +246,14 @@ class ApiCall:
             return False
 
     def __deal_error_md_tags(self, all_context, api_context, include_end: bool = True):
-        error_md_tags = ["```", "```python", "```xml", "```json", "```markdown"]
+        error_md_tags = [
+            "```",
+            "```python",
+            "```xml",
+            "```json",
+            "```markdown",
+            "```sql",
+        ]
         if include_end == False:
             md_tag_end = ""
         else:
@@ -265,7 +272,6 @@ class ApiCall:
         return all_context
 
     def api_view_context(self, all_context: str, display_mode: bool = False):
-        error_mk_tags = ["```", "```python", "```xml"]
         call_context_map = extract_content_open_ending(
             all_context, self.agent_prefix, self.agent_end, True
         )
@@ -298,8 +304,8 @@ class ApiCall:
                     now_time = datetime.now().timestamp() * 1000
                     cost = (now_time - self.start_time) / 1000
                     cost_str = "{:.2f}".format(cost)
-                    for tag in error_mk_tags:
-                        all_context = all_context.replace(tag + api_context, api_context)
+                    all_context = self.__deal_error_md_tags(all_context, api_context)
+
                     all_context = all_context.replace(
                         api_context,
                         f'\nWaiting...{cost_str}S\n',
                     )
@@ -377,8 +383,8 @@ class ApiCall:
                 param["type"] = api_status.name
                 if api_status.args:
                     param["sql"] = api_status.args["sql"]
-                if api_status.err_msg:
-                    param["err_msg"] = api_status.err_msg
+                # if api_status.err_msg:
+                #     param["err_msg"] = api_status.err_msg
 
                 if api_status.api_result:
                     param["data"] = api_status.api_result
@@ -448,33 +454,39 @@ class ApiCall:
         Returns:
             ChartView protocol text
         """
-        if self.__is_need_wait_plugin_call(llm_text):
-            # wait api call generate complete
-            if self.check_last_plugin_call_ready(llm_text):
-                self.update_from_context(llm_text)
-                for key, value in self.plugin_status_map.items():
-                    if value.status == Status.TODO.value:
-                        value.status = Status.RUNNING.value
-                        logging.info(f"sql展示执行:{value.name},{value.args}")
-                        try:
-                            sql = value.args["sql"]
-                            if sql is not None and len(sql) > 0:
-                                data_df = sql_run_func(sql)
-                                value.df = data_df
-                                value.api_result = json.loads(
-                                    data_df.to_json(
-                                        orient="records",
-                                        date_format="iso",
-                                        date_unit="s",
-                                    )
-                                )
-                                value.status = Status.COMPLETED.value
-                            else:
-                                value.status = Status.FAILED.value
-                                value.err_msg = "No executable sql!"
-
-                        except Exception as e:
-                            value.status = Status.FAILED.value
-                            value.err_msg = str(e)
-                        value.end_time = datetime.now().timestamp() * 1000
+        try:
+            if self.__is_need_wait_plugin_call(llm_text):
+                # wait api call generate complete
+                if self.check_last_plugin_call_ready(llm_text):
+                    self.update_from_context(llm_text)
+                    for key, value in self.plugin_status_map.items():
+                        if value.status == Status.TODO.value:
+                            value.status = Status.RUNNING.value
+                            logging.info(f"sql展示执行:{value.name},{value.args}")
+                            try:
+                                sql = value.args["sql"]
+                                if sql is not None and len(sql) > 0:
+                                    data_df = sql_run_func(sql)
+                                    value.df = data_df
+                                    value.api_result = json.loads(
+                                        data_df.to_json(
+                                            orient="records",
+                                            date_format="iso",
+                                            date_unit="s",
+                                        )
+                                    )
+                                    value.status = Status.COMPLETED.value
+                                else:
+                                    value.status = Status.FAILED.value
+                                    value.err_msg = "No executable sql!"
+
+                            except Exception as e:
+                                value.status = Status.FAILED.value
+                                value.err_msg = str(e)
+                            value.end_time = datetime.now().timestamp() * 1000
+        except Exception as e:
+            logging.error("Api parsing exception", e)
+            value.status = Status.FAILED.value
+            value.err_msg = "Api parsing exception," + str(e)
 
         return self.api_view_context(llm_text, True)
diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py
index 8fd242882..33ab91b7d 100644
--- a/pilot/model/model_adapter.py
+++ b/pilot/model/model_adapter.py
@@ -132,6 +132,9 @@ class LLMModelAdaper:
         conv = conv.copy()
         system_messages = []
+        user_messages = []
+        ai_messages = []
+
         for message in messages:
             role, content = None, None
             if isinstance(message, ModelMessage):
@@ -147,17 +150,30 @@
                 # Support for multiple system messages
                 system_messages.append(content)
             elif role == ModelMessageRoleType.HUMAN:
-                conv.append_message(conv.roles[0], content)
+                # conv.append_message(conv.roles[0], content)
+                user_messages.append(content)
             elif role == ModelMessageRoleType.AI:
-                conv.append_message(conv.roles[1], content)
+                # conv.append_message(conv.roles[1], content)
+                ai_messages.append(content)
             else:
                 raise ValueError(f"Unknown role: {role}")
 
+        can_use_system = ""
         if system_messages:
-            if isinstance(conv, Conversation):
-                conv.set_system_message("".join(system_messages))
-            else:
-                conv.update_system_message("".join(system_messages))
+            # TODO vicuna 兼容 测试完放弃
+            user_messages[-1] = system_messages[-1]
+            if len(system_messages) > 1:
+                can_use_system = system_messages[0]
+
+        for i in range(len(user_messages)):
+            conv.append_message(conv.roles[0], user_messages[i])
+            if i < len(ai_messages):
+                conv.append_message(conv.roles[1], ai_messages[i])
+
+        if isinstance(conv, Conversation):
+            conv.set_system_message(can_use_system)
+        else:
+            conv.update_system_message(can_use_system)
 
         # Add a blank message for the assistant.
         conv.append_message(conv.roles[1], None)
diff --git a/pilot/model/operator/model_operator.py b/pilot/model/operator/model_operator.py
index 6486e8373..2f051377a 100644
--- a/pilot/model/operator/model_operator.py
+++ b/pilot/model/operator/model_operator.py
@@ -171,6 +171,8 @@ class ModelCacheBranchOperator(BranchOperator[Dict, Dict]):
 
         async def check_cache_true(input_value: Dict) -> bool:
             # Check if the cache contains the result for the given input
+            if not input_value["model_cache_enable"]:
+                return False
             cache_dict = _parse_cache_key_dict(input_value)
             cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
             cache_value = await self._client.get(cache_key)
diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py
index acc82907c..cfd47fd18 100644
--- a/pilot/model/proxy/llms/wenxin.py
+++ b/pilot/model/proxy/llms/wenxin.py
@@ -26,6 +26,41 @@ def _build_access_token(api_key: str, secret_key: str) -> str:
     return res.json().get("access_token")
 
 
+def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+    chat_round = 0
+    wenxin_messages = []
+
+    last_usr_message = ""
+    system_messages = []
+
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            last_usr_message = message.content
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            system_messages.append(message.content)
+        elif message.role == ModelMessageRoleType.AI:
+            last_ai_message = message.content
+            wenxin_messages.append({"role": "user", "content": last_usr_message})
+            wenxin_messages.append({"role": "assistant", "content": last_ai_message})
+
+    # build last user message
+
+    if len(system_messages) > 0:
+        if len(system_messages) > 1:
+            end_message = system_messages[-1]
+        else:
+            last_message = messages[-1]
+            if last_message.role == ModelMessageRoleType.HUMAN:
+                end_message = system_messages[-1] + "\n" + last_message.content
+            else:
+                end_message = system_messages[-1]
+    else:
+        last_message = messages[-1]
+        end_message = last_message.content
+    wenxin_messages.append({"role": "user", "content": end_message})
+    return wenxin_messages, system_messages
+
+
 def wenxin_generate_stream(
     model: ProxyModel, tokenizer, params, device, context_len=2048
 ):
@@ -40,8 +75,9 @@ def wenxin_generate_stream(
     if not model_version:
         yield f"Unsupport model version {model_name}"
 
-    proxy_api_key = model_params.proxy_api_key
-    proxy_api_secret = model_params.proxy_api_secret
+    keys: [] = model_params.proxy_api_key.split(";")
+    proxy_api_key = keys[0]
+    proxy_api_secret = keys[1]
     access_token = _build_access_token(proxy_api_key, proxy_api_secret)
 
     headers = {"Content-Type": "application/json", "Accept": "application/json"}
@@ -51,40 +87,42 @@ def wenxin_generate_stream(
     if not access_token:
         yield "Failed to get access token. please set the correct api_key and secret key."
-    history = []
-
     messages: List[ModelMessage] = params["messages"]
     # Add history conversation
+    # system = ""
+    # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
+    #     role_define = messages.pop(0)
+    #     system = role_define.content
+    # else:
+    #     message = messages.pop(0)
+    #     if message.role == ModelMessageRoleType.HUMAN:
+    #         history.append({"role": "user", "content": message.content})
+    # for message in messages:
+    #     if message.role == ModelMessageRoleType.SYSTEM:
+    #         history.append({"role": "user", "content": message.content})
+    #     # elif message.role == ModelMessageRoleType.HUMAN:
+    #     #     history.append({"role": "user", "content": message.content})
+    #     elif message.role == ModelMessageRoleType.AI:
+    #         history.append({"role": "assistant", "content": message.content})
+    #     else:
+    #         pass
+    #
+    # # temp_his = history[::-1]
+    # temp_his = history
+    # last_user_input = None
+    # for m in temp_his:
+    #     if m["role"] == "user":
+    #         last_user_input = m
+    #         break
+    #
+    # if last_user_input:
+    #     history.remove(last_user_input)
+    #     history.append(last_user_input)
+    #
+    history, systems = __convert_2_wenxin_messages(messages)
     system = ""
-    if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM:
-        role_define = messages.pop(0)
-        system = role_define.content
-    else:
-        message = messages.pop(0)
-        if message.role == ModelMessageRoleType.HUMAN:
-            history.append({"role": "user", "content": message.content})
-    for message in messages:
-        if message.role == ModelMessageRoleType.SYSTEM:
-            history.append({"role": "user", "content": message.content})
-        # elif message.role == ModelMessageRoleType.HUMAN:
-        #     history.append({"role": "user", "content": message.content})
-        elif message.role == ModelMessageRoleType.AI:
-            history.append({"role": "assistant", "content": message.content})
-        else:
-            pass
-
-    # temp_his = history[::-1]
-    temp_his = history
-    last_user_input = None
-    for m in temp_his:
-        if m["role"] == "user":
-            last_user_input = m
-            break
-
-    if last_user_input:
-        history.remove(last_user_input)
-        history.append(last_user_input)
-
+    if systems and len(systems) > 0:
+        system = systems[0]
     payload = {
         "messages": history,
         "system": system,
diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py
index 89e7dd9a0..c5fabe8ed 100644
--- a/pilot/model/proxy/llms/zhipu.py
+++ b/pilot/model/proxy/llms/zhipu.py
@@ -8,6 +8,41 @@
 from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
 
 CHATGLM_DEFAULT_MODEL = "chatglm_pro"
 
+
+def __convert_2_wenxin_messages(messages: List[ModelMessage]):
+    chat_round = 0
+    wenxin_messages = []
+
+    last_usr_message = ""
+    system_messages = []
+
+    for message in messages:
+        if message.role == ModelMessageRoleType.HUMAN:
+            last_usr_message = message.content
+        elif message.role == ModelMessageRoleType.SYSTEM:
+            system_messages.append(message.content)
+        elif message.role == ModelMessageRoleType.AI:
+            last_ai_message = message.content
+            wenxin_messages.append({"role": "user", "content": last_usr_message})
+            wenxin_messages.append({"role": "assistant", "content": last_ai_message})
+
+    # build last user message
+
+    if len(system_messages) > 0:
+        if len(system_messages) > 1:
+            end_message = system_messages[-1]
+        else:
+            last_message = messages[-1]
+            if last_message.role == ModelMessageRoleType.HUMAN:
+                end_message = system_messages[-1] + "\n" + last_message.content
+            else:
+                end_message = system_messages[-1]
+    else:
+        last_message = messages[-1]
+        end_message = last_message.content
+    wenxin_messages.append({"role": "user", "content": end_message})
"content": end_message}) + return wenxin_messages, system_messages + + def zhipu_generate_stream( model: ProxyModel, tokenizer, params, device, context_len=2048 ): @@ -22,40 +57,40 @@ def zhipu_generate_stream( import zhipuai zhipuai.api_key = proxy_api_key - history = [] messages: List[ModelMessage] = params["messages"] # Add history conversation - system = "" - if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM: - role_define = messages.pop(0) - system = role_define.content - else: - message = messages.pop(0) - if message.role == ModelMessageRoleType.HUMAN: - history.append({"role": "user", "content": message.content}) - for message in messages: - if message.role == ModelMessageRoleType.SYSTEM: - history.append({"role": "user", "content": message.content}) - # elif message.role == ModelMessageRoleType.HUMAN: - # history.append({"role": "user", "content": message.content}) - elif message.role == ModelMessageRoleType.AI: - history.append({"role": "assistant", "content": message.content}) - else: - pass - - # temp_his = history[::-1] - temp_his = history - last_user_input = None - for m in temp_his: - if m["role"] == "user": - last_user_input = m - break - - if last_user_input: - history.remove(last_user_input) - history.append(last_user_input) + # system = "" + # if len(messages) > 1 and messages[0].role == ModelMessageRoleType.SYSTEM: + # role_define = messages.pop(0) + # system = role_define.content + # else: + # message = messages.pop(0) + # if message.role == ModelMessageRoleType.HUMAN: + # history.append({"role": "user", "content": message.content}) + # for message in messages: + # if message.role == ModelMessageRoleType.SYSTEM: + # history.append({"role": "user", "content": message.content}) + # # elif message.role == ModelMessageRoleType.HUMAN: + # # history.append({"role": "user", "content": message.content}) + # elif message.role == ModelMessageRoleType.AI: + # history.append({"role": "assistant", "content": message.content}) + # else: + # pass + # + # # temp_his = history[::-1] + # temp_his = history + # last_user_input = None + # for m in temp_his: + # if m["role"] == "user": + # last_user_input = m + # break + # + # if last_user_input: + # history.remove(last_user_input) + # history.append(last_user_input) + history, systems = __convert_2_wenxin_messages(messages) res = zhipuai.model_api.sse_invoke( model=proxyllm_backend, prompt=history, diff --git a/pilot/openapi/api_v1/api_v1.py b/pilot/openapi/api_v1/api_v1.py index d3f7e06c5..bea74d0a5 100644 --- a/pilot/openapi/api_v1/api_v1.py +++ b/pilot/openapi/api_v1/api_v1.py @@ -333,7 +333,6 @@ def get_hist_messages(conv_uid: str): history_messages: List[OnceConversation] = history_mem.get_messages() if history_messages: for once in history_messages: - print(f"once:{once}") model_name = once.get("model_name", CFG.LLM_MODEL) once_message_vos = [ message2Vo(element, once["chat_order"], model_name) diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py index 9d65e5041..a7268ab68 100644 --- a/pilot/out_parser/base.py +++ b/pilot/out_parser/base.py @@ -119,7 +119,9 @@ class BaseOutputParser(ABC): print("un_stream ai response:", ai_response) return ai_response else: - raise ValueError("Model server error!code=" + resp_obj_ex["error_code"]) + raise ValueError( + f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}""" + ) def __illegal_json_ends(self, s): temp_json = s @@ -206,11 +208,16 @@ class BaseOutputParser(ABC): if not cleaned_output.startswith("{") or not 
         if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
             logger.info("illegal json processing:\n" + cleaned_output)
             cleaned_output = self.__extract_json(cleaned_output)
+
+        if not cleaned_output or len(cleaned_output) <= 0:
+            return model_out_text
+
         cleaned_output = (
             cleaned_output.strip()
             .replace("\\n", " ")
             .replace("\n", " ")
             .replace("\\", " ")
+            .replace("\_", "_")
         )
         cleaned_output = self.__illegal_json_ends(cleaned_output)
         return cleaned_output
@@ -248,7 +255,9 @@
 
 
 def _parse_model_response(response: ResponseTye):
-    if isinstance(response, ModelOutput):
+    if response is None:
+        resp_obj_ex = ""
+    elif isinstance(response, ModelOutput):
         resp_obj_ex = asdict(response)
     elif isinstance(response, str):
         resp_obj_ex = json.loads(response)
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 407010ca2..8925f5972 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -58,6 +58,7 @@ class BaseChat(ABC):
             chat_param["model_name"] if chat_param["model_name"] else CFG.LLM_MODEL
         )
         self.llm_echo = False
+        self.model_cache_enable = chat_param.get("model_cache_enable", False)
 
         ### load prompt template
         # self.prompt_template: PromptTemplate = CFG.prompt_templates[
@@ -118,6 +119,9 @@ class BaseChat(ABC):
     def do_action(self, prompt_response):
         return prompt_response
 
+    def message_adjust(self):
+        pass
+
     def get_llm_speak(self, prompt_define_response):
         if hasattr(prompt_define_response, "thoughts"):
             if isinstance(prompt_define_response.thoughts, dict):
@@ -210,6 +214,7 @@ class BaseChat(ABC):
             "BaseChat.stream_call", metadata=self._get_span_metadata(payload)
         )
         payload["span_id"] = span.span_id
+        payload["model_cache_enable"] = self.model_cache_enable
         try:
             async for output in await self._model_stream_operator.call_stream(
                 call_data={"data": payload}
@@ -243,6 +248,7 @@ class BaseChat(ABC):
             "BaseChat.nostream_call", metadata=self._get_span_metadata(payload)
         )
         payload["span_id"] = span.span_id
+        payload["model_cache_enable"] = self.model_cache_enable
        try:
             with root_tracer.start_span("BaseChat.invoke_worker_manager.generate"):
                 model_output = await self._model_operator.call(
@@ -291,6 +297,8 @@ class BaseChat(ABC):
                 view_message = view_message.replace("\n", "\\n")
                 self.current_message.add_view_message(view_message)
 
+            self.message_adjust()
+
             span.end()
         except Exception as e:
             print(traceback.format_exc())
@@ -307,15 +315,9 @@ class BaseChat(ABC):
         payload = await self.__call_base()
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
+        payload["model_cache_enable"] = self.model_cache_enable
         try:
-            from pilot.model.cluster import WorkerManagerFactory
-
-            worker_manager = CFG.SYSTEM_APP.get_component(
-                ComponentType.WORKER_MANAGER_FACTORY, WorkerManagerFactory
-            ).create()
-
-            model_output = await worker_manager.generate(payload)
-
+            model_output = await self._model_operator.call(call_data={"data": payload})
             ### output parse
             ai_response_text = (
                 self.prompt_template.output_parser.parse_model_nostream_resp(
@@ -576,23 +578,23 @@ def _build_model_operator(
 ) -> BaseOperator:
     """Builds and returns a model processing workflow (DAG) operator.
 
-    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model. 
-    It includes caching and branching logic to either fetch results from a cache or process 
+    This function constructs a Directed Acyclic Graph (DAG) for processing data using a model.
+    It includes caching and branching logic to either fetch results from a cache or process
     data using the model. It supports both streaming and non-streaming modes.
 
     .. code-block:: python
 
         input_node >> cache_check_branch_node
         cache_check_branch_node >> model_node >> save_cached_node >> join_node
-        cache_check_branch_node >> cached_node >> join_node 
+        cache_check_branch_node >> cached_node >> join_node
 
     equivalent to::
-    
+
                          -> model_node -> save_cached_node ->
                         /                                    \
         input_node -> cache_check_branch_node ---> join_node
-                        \                          /
+                        \                         /
                          -> cached_node ------------------- ->
-    
+
     Args:
         is_stream (bool): Flag to determine if the operator should process data in streaming mode.
         dag_name (str): Name of the DAG.
diff --git a/pilot/scene/chat_agent/chat.py b/pilot/scene/chat_agent/chat.py
index 81af1b3b1..09968d8f3 100644
--- a/pilot/scene/chat_agent/chat.py
+++ b/pilot/scene/chat_agent/chat.py
@@ -64,7 +64,12 @@ class ChatAgent(BaseChat):
         return input_values
 
     def stream_plugin_call(self, text):
-        text = text.replace("\n", " ")
+        text = (
+            text.replace("\\n", " ")
+            .replace("\n", " ")
+            .replace("\_", "_")
+            .replace("\\", " ")
+        )
         with root_tracer.start_span(
             "ChatAgent.stream_plugin_call.api_call", metadata={"text": text}
         ):
diff --git a/pilot/scene/chat_agent/prompt.py b/pilot/scene/chat_agent/prompt.py
index 7d01bfd2f..94151544a 100644
--- a/pilot/scene/chat_agent/prompt.py
+++ b/pilot/scene/chat_agent/prompt.py
@@ -42,7 +42,8 @@ _DEFAULT_TEMPLATE_ZH = """
     3.根据上面约束的方式生成每个工具的调用,对于工具使用的提示文本,需要在工具使用前生成
     4.如果用户目标无法理解和意图不明确,优先使用搜索引擎工具
     5.参数内容可能需要根据用户的目标推理得到,不仅仅是从文本提取
-    6.约束条件和工具信息作为推理过程的辅助信息,不要表达在给用户的输出内容中
+    6.约束条件和工具信息作为推理过程的辅助信息,对应内容不要表达在给用户的输出内容中
+    7.不要把部分内容放在markdown标签里
     {expand_constraints}
 工具列表:
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
index daccff300..bede9bc40 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/chat.py
@@ -119,7 +119,12 @@ class ChatExcel(BaseChat):
         return result
 
     def stream_plugin_call(self, text):
-        text = text.replace("\n", " ")
+        text = (
+            text.replace("\\n", " ")
+            .replace("\n", " ")
+            .replace("\_", "_")
+            .replace("\\", " ")
+        )
         with root_tracer.start_span(
             "ChatExcel.stream_plugin_call.run_display_sql", metadata={"text": text}
         ):
diff --git a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
index 641a845fc..0af8b3c3a 100644
--- a/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
+++ b/pilot/scene/chat_data/chat_excel/excel_analyze/prompt.py
@@ -12,7 +12,7 @@ CFG = Config()
 _PROMPT_SCENE_DEFINE_EN = "You are a data analysis expert. "
 
 _DEFAULT_TEMPLATE_EN = """
-Please use the data structure information in the above historical dialogue and combine it with data analysis to answer the user's questions while satisfying the constraints.
+Please use the data structure column analysis information generated in the above historical dialogue to answer the user's questions through duckdb sql data analysis under the following constraints.
 
 Constraint:
     1.Please fully understand the user's problem and use duckdb sql for analysis. The analysis content is returned in the output format required below. Please output the sql in the corresponding sql parameter.
@@ -30,14 +30,14 @@ User Questions:
 _PROMPT_SCENE_DEFINE_ZH = """你是一个数据分析专家!"""
 
 _DEFAULT_TEMPLATE_ZH = """
-请使用上述历史对话中的数据结构信息,在满足下面约束条件下通过数据分析回答用户的问题。
+请使用历史对话中的数据结构信息,在满足下面约束条件下通过duckdb sql数据分析回答用户的问题。
 约束条件:
     1.请充分理解用户的问题,使用duckdb sql的方式进行分析, 分析内容按下面要求的输出格式返回,sql请输出在对应的sql参数中
     2.请从如下给出的展示方式种选择最优的一种用以进行数据渲染,将类型名称放入返回要求格式的name参数值种,如果找不到最合适的则使用'Table'作为展示方式,可用数据展示方式如下: {disply_type}
     3.SQL中需要使用的表名是: {table_name},请检查你生成的sql,不要使用没在数据结构中的列名,。
     4.优先使用数据分析的方式回答,如果用户问题不涉及数据分析内容,你可以按你的理解进行回答
-    5.要求的输出格式中部分需要被代码解析执行,请确保这部分内容按要求输出
-请确保你的输出格式如下:
+    5.要求的输出格式中部分需要被代码解析执行,请确保这部分内容按要求输出,不要参考历史信息的返回格式,请按下面要求返回
+请确保你的输出内容格式如下:
     对用户说的想法摘要.[数据展示方式][正确的duckdb数据分析sql]
 
 用户问题:{user_input}
@@ -59,7 +59,7 @@ PROMPT_NEED_STREAM_OUT = True
 # Temperature is a configuration hyperparameter that controls the randomness of language model output.
 # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
 # For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
-PROMPT_TEMPERATURE = 0.8
+PROMPT_TEMPERATURE = 0.3
 
 prompt = PromptTemplate(
     template_scene=ChatScene.ChatExcel.value(),
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
index ff4d243d7..a663e2756 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/chat.py
@@ -1,10 +1,7 @@
 import json
 from typing import Any, Dict
 
-from pilot.scene.base_message import (
-    HumanMessage,
-    ViewMessage,
-)
+from pilot.scene.base_message import HumanMessage, ViewMessage, AIMessage
 from pilot.scene.base_chat import BaseChat
 from pilot.scene.base import ChatScene
 from pilot.common.sql_database import Database
@@ -59,3 +56,14 @@ class ExcelLearning(BaseChat):
             "file_name": self.excel_reader.excel_file_name,
         }
         return input_values
+
+    def message_adjust(self):
+        ### adjust learning result in messages
+        view_message = ""
+        for message in self.current_message.messages:
+            if message.type == ViewMessage.type:
+                view_message = message.content
+
+        for message in self.current_message.messages:
+            if message.type == AIMessage.type:
+                message.content = view_message
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
index 8728cd744..b52558e47 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/out_parser.py
@@ -36,40 +36,39 @@ class LearningExcelOutputParser(BaseOutputParser):
             return ExcelResponse(desciption=desciption, clounms=clounms, plans=plans)
         except Exception as e:
             logger.error(f"parse_prompt_response Faild!{str(e)}")
-            self.is_downgraded = True
-            return ExcelResponse(
-                desciption=model_out_text, clounms=self.data_schema, plans=None
-            )
+            clounms = []
+            for name in self.data_schema:
+                clounms.append({name: "-"})
+            return ExcelResponse(desciption=model_out_text, clounms=clounms, plans=None)
+
+    def __build_colunms_html(self, clounms_data):
+        html_colunms = f"### **Data Structure**\n"
+        column_index = 0
+        for item in clounms_data:
+            column_index += 1
+            keys = item.keys()
+            for key in keys:
+                html_colunms = (
+                    html_colunms + f"- **{column_index}.[{key}]** _{item[key]}_\n"
+                )
+        return html_colunms
+
+    def __build_plans_html(self, plans_data):
+        html_plans = f"### **Analysis plans**\n"
+        index = 0
+        if plans_data:
+            for item in plans_data:
+                index += 1
+                html_plans = html_plans + f"{item} \n"
+        return html_plans
 
     def parse_view_response(self, speak, data, prompt_response) -> str:
         if data and not isinstance(data, str):
             ### tool out data to table view
             html_title = f"### **Data Summary**\n{data.desciption} "
-            html_colunms = f"### **Data Structure**\n"
-            if self.is_downgraded:
-                column_index = 0
-                for item in data.clounms:
-                    column_index += 1
-                    html_colunms = (
-                        html_colunms + f"- **{column_index}.[{item}]** _未知_\n"
-                    )
-            else:
-                column_index = 0
-                for item in data.clounms:
-                    column_index += 1
-                    keys = item.keys()
-                    for key in keys:
-                        html_colunms = (
-                            html_colunms
-                            + f"- **{column_index}.[{key}]** _{item[key]}_\n"
-                        )
+            html_colunms = self.__build_colunms_html(data.clounms)
+            html_plans = self.__build_plans_html(data.plans)
 
-            html_plans = f"### **Recommended analysis plan**\n"
-            index = 0
-            if data.plans:
-                for item in data.plans:
-                    index += 1
-                    html_plans = html_plans + f"{item} \n"
             html = f"""{html_title}\n{html_colunms}\n{html_plans}"""
             return html
         else:
diff --git a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
index b2630bc46..e5722e554 100644
--- a/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
+++ b/pilot/scene/chat_data/chat_excel/excel_learning/prompt.py
@@ -28,30 +28,27 @@ _DEFAULT_TEMPLATE_ZH = """
 下面是用户文件{file_name}的一部分数据,请学习理解该数据的结构和内容,按要求输出解析结果:
 {data_example}
 分析各列数据的含义和作用,并对专业术语进行简单明了的解释, 如果是时间类型请给出时间格式类似:yyyy-MM-dd HH:MM:ss.
+将列名作为属性名,分析解释作为属性值,组成json数组,并输出在返回json内容的ColumnAnalysis属性中.
 请不要修改或者翻译列名,确保和给出数据列名一致.
+针对数据从不同维度提供一些有用的分析思路给用户。
 
-提供一些分析方案思路,请一步一步思考。
-
-请以JSON格式返回您的答案,返回格式如下:
+请一步一步思考,确保只以JSON格式回答,具体格式如下:
 {response}
 """
 
 _RESPONSE_FORMAT_SIMPLE_ZH = {
     "DataAnalysis": "数据内容分析总结",
-    "ColumnAnalysis": [{"column name1": "字段1介绍,专业术语解释(请尽量简单明了)"}],
-    "AnalysisProgram": ["1.分析方案1,图表展示方式1", "2.分析方案2,图表展示方式2"],
+    "ColumnAnalysis": [{"column name": "字段1介绍,专业术语解释(请尽量简单明了)"}],
+    "AnalysisProgram": ["1.分析方案1", "2.分析方案2"],
 }
 _RESPONSE_FORMAT_SIMPLE_EN = {
     "DataAnalysis": "Data content analysis summary",
     "ColumnAnalysis": [
         {
-            "column name1": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"
+            "column name": "Introduction to Column 1 and explanation of professional terms (please try to be as simple and clear as possible)"
         }
     ],
-    "AnalysisProgram": [
-        "1. Analysis plan 1, chart display type 1",
-        "2. Analysis plan 2, chart display type 2",
-    ],
+    "AnalysisProgram": ["1. Analysis plan ", "2. Analysis plan "],
 }
 
 RESPONSE_FORMAT_SIMPLE = (
@@ -75,7 +72,7 @@ PROMPT_NEED_STREAM_OUT = False
 # Temperature is a configuration hyperparameter that controls the randomness of language model output.
 # A high temperature produces more unpredictable and creative results, while a low temperature produces more common and conservative output.
 # For example, if you adjust the temperature to 0.5, the model will usually generate text that is more predictable and less creative than if you set the temperature to 1.0.
-PROMPT_TEMPERATURE = 0.5
+PROMPT_TEMPERATURE = 0.8
 
 prompt = PromptTemplate(
     template_scene=ChatScene.ExcelLearning.value(),
diff --git a/pilot/scene/chat_db/auto_execute/chat.py b/pilot/scene/chat_db/auto_execute/chat.py
index 2467588cc..00ddce1a2 100644
--- a/pilot/scene/chat_db/auto_execute/chat.py
+++ b/pilot/scene/chat_db/auto_execute/chat.py
@@ -71,7 +71,8 @@ class ChatWithDbAutoExecute(BaseChat):
         )
 
         input_values = {
-            # "input": self.current_user_input,
+            "db_name": self.db_name,
+            "user_input": self.current_user_input,
             "top_k": str(self.top_k),
             "dialect": self.database.dialect,
             "table_info": table_infos,
diff --git a/pilot/scene/chat_db/auto_execute/out_parser.py b/pilot/scene/chat_db/auto_execute/out_parser.py
index 9b485ce6b..1cd5765da 100644
--- a/pilot/scene/chat_db/auto_execute/out_parser.py
+++ b/pilot/scene/chat_db/auto_execute/out_parser.py
@@ -26,7 +26,7 @@ class DbChatOutputParser(BaseOutputParser):
     def __init__(self, sep: str, is_stream_out: bool):
         super().__init__(sep=sep, is_stream_out=is_stream_out)
 
-    def is_sql_statement(statement):
+    def is_sql_statement(self, statement):
         parsed = sqlparse.parse(statement)
         if not parsed:
             return False
@@ -42,19 +42,26 @@ class DbChatOutputParser(BaseOutputParser):
         if self.is_sql_statement(clean_str):
             return SqlAction(clean_str, "")
         else:
-            response = json.loads(clean_str)
-            for key in sorted(response):
-                if key.strip() == "sql":
-                    sql = response[key]
-                if key.strip() == "thoughts":
-                    thoughts = response[key]
-            return SqlAction(sql, thoughts)
+            try:
+                response = json.loads(clean_str)
+                for key in sorted(response):
+                    if key.strip() == "sql":
+                        sql = response[key]
+                    if key.strip() == "thoughts":
+                        thoughts = response[key]
+                return SqlAction(sql, thoughts)
+            except Exception as e:
+                logging.error("json load failed")
+                return SqlAction("", clean_str)
 
     def parse_view_response(self, speak, data, prompt_response) -> str:
         param = {}
         api_call_element = ET.Element("chart-view")
         err_msg = None
         try:
+            if not prompt_response.sql or len(prompt_response.sql) <= 0:
+                return f"""{speak}"""
+
             df = data(prompt_response.sql)
             param["type"] = "response_table"
             param["sql"] = prompt_response.sql
diff --git a/pilot/scene/chat_db/auto_execute/prompt.py b/pilot/scene/chat_db/auto_execute/prompt.py
index d4a46b9fd..ba1e01870 100644
--- a/pilot/scene/chat_db/auto_execute/prompt.py
+++ b/pilot/scene/chat_db/auto_execute/prompt.py
@@ -13,8 +13,11 @@ CFG = Config()
 _PROMPT_SCENE_DEFINE_EN = "You are a database expert. "
 _PROMPT_SCENE_DEFINE_ZH = "你是一个数据库专家. "
 
 _DEFAULT_TEMPLATE_EN = """
-Please create a syntactically correct {dialect} sql based on the user question, use the following tables schema to generate sql:
-    {table_info}
+Please answer the user's question based on the database selected by the user and some of the available table structure definitions of the database.
+Database name:
+    {db_name}
+Table structure definition:
+    {table_info}
 
 Constraint:
     1.Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results.
@@ -22,6 +25,8 @@ Constraint:
     3.Use as few tables as possible when querying.
     4.Please check the correctness of the SQL and ensure that the query performance is optimized under correct conditions.
 
+User Question:
+    {user_input}
 Please think step by step and respond according to the following JSON format:
     {response}
 Ensure the response is correct json and can be parsed by Python json.loads.
@@ -29,15 +34,20 @@
""" _DEFAULT_TEMPLATE_ZH = """ -请根据用户输入问题,使用如下的表结构定义创建一个语法正确的 {dialect} sql: +请根据用户选择的数据库和该库的部分可用表结构定义来回答用户问题. +数据库名: + {db_name} +表结构定义: {table_info} 约束: + 1. 请理解用户意图根据用户输入问题,使用给出表结构定义创建一个语法正确的 {dialect} sql,如果不需要sql,则直接回答用户问题。 1. 除非用户在问题中指定了他希望获得的具体数据行数,否则始终将查询限制为最多 {top_k} 个结果。 2. 只能使用表结构信息中提供的表来生成 sql,如果无法根据提供的表结构中生成 sql ,请说:“提供的表结构信息不足以生成 sql 查询。” 禁止随意捏造信息。 3. 请注意生成SQL时不要弄错表和列的关系 4. 请检查SQL的正确性,并保证正确的情况下优化查询性能 - +用户问题: + {user_input} 请一步步思考并按照以下JSON格式回复: {response} 确保返回正确的json并且可以被Python json.loads方法解析. diff --git a/pilot/scene/chat_knowledge/refine_summary/prompt.py b/pilot/scene/chat_knowledge/refine_summary/prompt.py index 6898db5cd..ee39a480c 100644 --- a/pilot/scene/chat_knowledge/refine_summary/prompt.py +++ b/pilot/scene/chat_knowledge/refine_summary/prompt.py @@ -13,7 +13,9 @@ CFG = Config() PROMPT_SCENE_DEFINE = """A chat between a curious user and an artificial intelligence assistant, who very familiar with database related knowledge. The assistant gives helpful, detailed, professional and polite answers to the user's questions.""" -_DEFAULT_TEMPLATE_ZH = """根据提供的上下文信息,我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请根据你之前推理的内容进行最终的总结,并且总结回答的时候最好按照1.2.3.进行总结.""" +_DEFAULT_TEMPLATE_ZH = ( + """我们已经提供了一个到某一点的现有总结:{existing_answer}\n 请根据你之前推理的内容进行最终的总结,总结回答的时候最好按照1.2.3.进行.""" +) _DEFAULT_TEMPLATE_EN = """ We have provided an existing summary up to a certain point: {existing_answer}\nWe have the opportunity to refine the existing summary (only if needed) with some more context below. @@ -44,4 +46,4 @@ prompt = PromptTemplate( ) CFG.prompt_template_registry.register(prompt, is_default=True) -from ..v1 import prompt_chatglm \ No newline at end of file +from ..v1 import prompt_chatglm diff --git a/pilot/scene/chat_knowledge/summary/chat.py b/pilot/scene/chat_knowledge/summary/chat.py index 96e486ea8..7327b7a5b 100644 --- a/pilot/scene/chat_knowledge/summary/chat.py +++ b/pilot/scene/chat_knowledge/summary/chat.py @@ -21,9 +21,7 @@ class ExtractSummary(BaseChat): chat_param=chat_param, ) - # self.user_input = chat_param["current_user_input"] self.user_input = chat_param["select_param"] - # self.extract_mode = chat_param["select_param"] def generate_input_values(self): input_values = { diff --git a/pilot/scene/chat_knowledge/summary/out_parser.py b/pilot/scene/chat_knowledge/summary/out_parser.py index b4a81d9cf..cc4de7356 100644 --- a/pilot/scene/chat_knowledge/summary/out_parser.py +++ b/pilot/scene/chat_knowledge/summary/out_parser.py @@ -1,9 +1,7 @@ -import json import logging -import re from typing import List, Tuple -from pilot.out_parser.base import BaseOutputParser, T +from pilot.out_parser.base import BaseOutputParser, T, ResponseTye from pilot.configs.config import Config CFG = Config() @@ -26,28 +24,9 @@ class ExtractSummaryParser(BaseOutputParser): def parse_view_response(self, speak, data) -> str: ### tool out data to table view return data - def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str: - ### tool out data to table view - resp_obj_ex = _parse_model_response(response) - if isinstance(resp_obj_ex, str): - resp_obj_ex = json.loads(resp_obj_ex) - if resp_obj_ex["error_code"] == 0: - all_text = resp_obj_ex["text"] - tmp_resp = all_text.split(sep) - last_index = -1 - for i in range(len(tmp_resp)): - if tmp_resp[i].find("assistant:") != -1: - last_index = i - ai_response = tmp_resp[last_index] - ai_response = ai_response.replace("assistant:", "") - ai_response = ai_response.replace("Assistant:", "") - ai_response = ai_response.replace("ASSISTANT:", "") - 
-            ai_response = ai_response.replace("\*", "*")
-            ai_response = ai_response.replace("\t", "")
-            ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
-            print("un_stream ai response:", ai_response)
-            return ai_response
-        else:
-            raise ValueError("Model server error!code=" + resp_obj_ex["error_code"])
+    def parse_model_nostream_resp(self, response: ResponseTye, sep: str) -> str:
+        try:
+            return super().parse_model_nostream_resp(response, sep)
+        except Exception as e:
+            return str(e)
diff --git a/pilot/scene/chat_knowledge/summary/prompt.py b/pilot/scene/chat_knowledge/summary/prompt.py
index 224e7f073..404b9079f 100644
--- a/pilot/scene/chat_knowledge/summary/prompt.py
+++ b/pilot/scene/chat_knowledge/summary/prompt.py
@@ -50,4 +50,4 @@ prompt = PromptTemplate(
 )
 
 CFG.prompt_template_registry.register(prompt, is_default=True)
-from ..v1 import prompt_chatglm
\ No newline at end of file
+from ..v1 import prompt_chatglm
diff --git a/pilot/server/knowledge/service.py b/pilot/server/knowledge/service.py
index 7fa8c09e8..fac027920 100644
--- a/pilot/server/knowledge/service.py
+++ b/pilot/server/knowledge/service.py
@@ -541,15 +541,22 @@ class KnowledgeService:
     async def _llm_extract_summary(
         self, doc: str, conn_uid: str, model_name: str = None
     ):
-        """Extract triplets from text by llm"""
+        """Extract triplets from text by llm
+        Args:
+            doc: Document
+            conn_uid: str,chat conversation id
+            model_name: str, model name
+        Returns:
+            chat: BaseChat, refine summary chat.
+        """
         from pilot.scene.base import ChatScene
-        import uuid
 
         chat_param = {
             "chat_session_id": conn_uid,
             "current_user_input": "",
             "select_param": doc,
             "model_name": model_name,
+            "model_cache_enable": False,
         }
         executor = CFG.SYSTEM_APP.get_component(
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
@@ -579,6 +586,8 @@
             model_name:model name str
             max_iteration:max iteration will call llm to summary
             concurrency_limit:the max concurrency threads to call llm
+        Returns:
+            Document: refine summary context document.
         """
         from pilot.scene.base import ChatScene
         from pilot.common.chat_util import llm_chat_response_nostream
@@ -595,6 +604,7 @@
                 "current_user_input": "",
                 "select_param": doc,
                 "model_name": model_name,
+                "model_cache_enable": True,
             }
             tasks.append(
                 llm_chat_response_nostream(