diff --git a/pilot/out_parser/base.py b/pilot/out_parser/base.py
index 9beb8b5f5..eb992c8eb 100644
--- a/pilot/out_parser/base.py
+++ b/pilot/out_parser/base.py
@@ -47,7 +47,7 @@ class BaseOutputParser(ABC):
         return code

     # TODO: bind this to the model later
-    def _parse_model_stream_resp(self, response, sep: str):
+    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):

         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -56,9 +56,8 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """
                 if data["error_code"] == 0:
-                    if "vicuna" in CFG.LLM_MODEL:
-
-                        output = data["text"].strip()
+                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                        output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
@@ -97,7 +96,7 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])

-    def parse_model_server_out(self, response):
+    def parse_model_server_out(self, response, skip_echo_len: int = 0):
         """
             parse the model server http response
         Args:
@@ -109,7 +108,7 @@ class BaseOutputParser(ABC):
         if not self.is_stream_out:
             return self._parse_model_nostream_resp(response, self.sep)
         else:
-            return self._parse_model_stream_resp(response, self.sep)
+            return self._parse_model_stream_resp(response, self.sep, skip_echo_len)

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -143,7 +142,7 @@ class BaseOutputParser(ABC):
         cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
         return cleaned_output

-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         """
             parse the ai response info to user view
         Args:
diff --git a/pilot/scene/base_chat.py b/pilot/scene/base_chat.py
index 650235d63..e1d178a12 100644
--- a/pilot/scene/base_chat.py
+++ b/pilot/scene/base_chat.py
@@ -128,6 +128,8 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
+
+        skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
@@ -139,7 +141,7 @@ class BaseChat(ABC):
                 timeout=120,
             )

-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)

             for resp_text_trunck in ai_response_text:
                 show_info = resp_text_trunck
diff --git a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py
index 0d2a7e49d..b17571edd 100644
--- a/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py
+++ b/pilot/scene/chat_knowledge/inner_db_summary/out_parser.py
@@ -15,7 +15,7 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         return ai_text["table"]

     def get_format_instructions(self) -> str:
diff --git a/pilot/scene/chat_normal/out_parser.py b/pilot/scene/chat_normal/out_parser.py
index 0f7ccd791..0b8277d63 100644
--- a/pilot/scene/chat_normal/out_parser.py
+++ b/pilot/scene/chat_normal/out_parser.py
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
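
Note on the change (not part of the diff): models like vicuna and guanaco echo the prompt at the head of every streamed chunk, so stream_call() computes skip_echo_len once from the prompt and the parser slices that prefix off each chunk. Below is a minimal, self-contained sketch of that idea; the names compute_skip_echo_len and strip_echo and the fake chunk stream are hypothetical stand-ins for the model server, not code from this PR.

def compute_skip_echo_len(prompt: str) -> int:
    # Same arithmetic as stream_call() above: treat the stop token "</s>"
    # as a single space, then skip one character past the echoed prompt.
    return len(prompt.replace("</s>", " ")) + 1


def strip_echo(chunks, skip_echo_len: int):
    # Each streamed chunk holds the full text so far (prompt + completion);
    # slicing off the first skip_echo_len characters leaves the completion.
    for text in chunks:
        yield text[skip_echo_len:].strip()


if __name__ == "__main__":
    # Hypothetical prompt and chunk stream standing in for the model server.
    prompt = "USER: What is 2 + 2? ASSISTANT:"
    skip = compute_skip_echo_len(prompt)
    fake_stream = [prompt + " The answer", prompt + " The answer is 4."]
    for out in strip_echo(fake_stream, skip):
        print(out)  # -> "The answer", then "The answer is 4."

Threading skip_echo_len through parse_model_server_out with a default of 0 keeps the non-streaming path and models that do not echo the prompt unaffected.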