chat with plugin bug fix

This commit is contained in:
yhjun1026 2023-06-01 15:32:01 +08:00
parent bacc31658e
commit 6648a671df
4 changed files with 10 additions and 12 deletions

View File

@ -47,7 +47,7 @@ class BaseOutputParser(ABC):
return code return code
# TODO 后续和模型绑定 # TODO 后续和模型绑定
def _parse_model_stream_resp(self, response, sep: str): def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk: if chunk:
@ -56,9 +56,8 @@ class BaseOutputParser(ABC):
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode. """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
""" """
if data["error_code"] == 0: if data["error_code"] == 0:
if "vicuna" in CFG.LLM_MODEL: if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
output = data["text"][skip_echo_len:].strip()
output = data["text"].strip()
else: else:
output = data["text"].strip() output = data["text"].strip()
@ -97,7 +96,7 @@ class BaseOutputParser(ABC):
else: else:
raise ValueError("Model server error!code=" + respObj_ex["error_code"]) raise ValueError("Model server error!code=" + respObj_ex["error_code"])
def parse_model_server_out(self, response): def parse_model_server_out(self, response, skip_echo_len: int = 0):
""" """
parse the model server http response parse the model server http response
Args: Args:
@ -109,7 +108,7 @@ class BaseOutputParser(ABC):
if not self.is_stream_out: if not self.is_stream_out:
return self._parse_model_nostream_resp(response, self.sep) return self._parse_model_nostream_resp(response, self.sep)
else: else:
return self._parse_model_stream_resp(response, self.sep) return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
""" """
@ -143,7 +142,7 @@ class BaseOutputParser(ABC):
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '') cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
return cleaned_output return cleaned_output
def parse_view_response(self, ai_text) -> str: def parse_view_response(self, ai_text, data) -> str:
""" """
parse the ai response info to user view parse the ai response info to user view
Args: Args:

View File

@ -128,6 +128,8 @@ class BaseChat(ABC):
def stream_call(self): def stream_call(self):
payload = self.__call_base() payload = self.__call_base()
skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
logger.info(f"Requert: \n{payload}") logger.info(f"Requert: \n{payload}")
ai_response_text = "" ai_response_text = ""
try: try:
@ -139,7 +141,7 @@ class BaseChat(ABC):
timeout=120, timeout=120,
) )
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response) ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
for resp_text_trunck in ai_response_text: for resp_text_trunck in ai_response_text:
show_info = resp_text_trunck show_info = resp_text_trunck

View File

@ -15,7 +15,7 @@ class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
return model_out_text return model_out_text
def parse_view_response(self, ai_text) -> str: def parse_view_response(self, ai_text, data) -> str:
return ai_text["table"] return ai_text["table"]
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:

View File

@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
def parse_prompt_response(self, model_out_text) -> T: def parse_prompt_response(self, model_out_text) -> T:
return model_out_text return model_out_text
def parse_view_response(self, ai_text) -> str:
return super().parse_view_response(ai_text)
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
pass pass