chat with plugin bug fix

yhjun1026 2023-06-01 15:32:01 +08:00
parent bacc31658e
commit 6648a671df
4 changed files with 10 additions and 12 deletions

View File

@@ -47,7 +47,7 @@ class BaseOutputParser(ABC):
         return code

     # TODO: bind this to the model later
-    def _parse_model_stream_resp(self, response, sep: str):
+    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
@@ -56,9 +56,8 @@ class BaseOutputParser(ABC):
                 """ TODO: multi-mode output handler; rewrite this for multi-model support using the adapter pattern.
                 """
                 if data["error_code"] == 0:
-                    if "vicuna" in CFG.LLM_MODEL:
-                        output = data["text"].strip()
+                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                        output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
@@ -97,7 +96,7 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error! code=" + respObj_ex["error_code"])

-    def parse_model_server_out(self, response):
+    def parse_model_server_out(self, response, skip_echo_len: int = 0):
         """
         Parse the model server HTTP response.
         Args:
@@ -109,7 +108,7 @@ class BaseOutputParser(ABC):
         if not self.is_stream_out:
             return self._parse_model_nostream_resp(response, self.sep)
         else:
-            return self._parse_model_stream_resp(response, self.sep)
+            return self._parse_model_stream_resp(response, self.sep, skip_echo_len)

     def parse_prompt_response(self, model_out_text) -> T:
         """
@@ -143,7 +142,7 @@ class BaseOutputParser(ABC):
         cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '')
         return cleaned_output

-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         """
         Parse the AI response into the user-facing view.
         Args:
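
The thrust of this file's change: streaming models such as vicuna and guanaco echo the prompt back at the head of the generated text, so the parser now receives skip_echo_len and slices the echo off before returning output. A minimal sketch of the idea (standalone and illustrative; the variable names and sample strings are not from the repo):

prompt = "Hello, how are you?"
# Mirrors the commit: stop token replaced by a space, plus one for the gap.
skip_echo_len = len(prompt.replace("</s>", " ")) + 1

# A vicuna-style response: echoed prompt followed by the completion.
raw_model_text = prompt + " I am fine, thank you."
output = raw_model_text[skip_echo_len:].strip()
print(output)  # I am fine, thank you.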

View File

@@ -128,6 +128,8 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
+        skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
+
         logger.info(f"Request: \n{payload}")
         ai_response_text = ""
         try:
@@ -139,7 +141,7 @@ class BaseChat(ABC):
                 timeout=120,
             )
-            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+            ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
             for resp_text_trunck in ai_response_text:
                 show_info = resp_text_trunck
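
Putting the two files together, the stream path now computes skip_echo_len from the prompt before handing the response to the parser. A hedged end-to-end sketch, assuming a requests-style model server that separates JSON chunks with b"\0" (as the parser's iter_lines(delimiter=b"\0") call suggests); the URL and payload fields are illustrative:

import json
import requests

def stream_chat(url: str, payload: dict):
    # Same offset the commit adds in stream_call().
    skip_echo_len = len(payload["prompt"].replace("</s>", " ")) + 1
    response = requests.post(url, json=payload, stream=True, timeout=120)
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not chunk:
            continue
        data = json.loads(chunk.decode())
        if data["error_code"] == 0:
            # Drop the echoed prompt, keep only the generated text so far.
            yield data["text"][skip_echo_len:].strip()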

View File

@@ -15,7 +15,7 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         return ai_text["table"]

     def get_format_instructions(self) -> str:

View File

@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text

-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
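
With the override deleted, this parser falls back to the base class's parse_view_response, so the signature change from the first file only has to be maintained in one place. A minimal illustration of the inheritance effect (the base method's body here is an assumption, shown only to make the fallback concrete):

class BaseOutputParser:
    def parse_view_response(self, ai_text, data) -> str:
        # Assumed default: pass the text through to the user view.
        return ai_text

class NormalChatOutputParser(BaseOutputParser):
    # No override: calls now resolve to BaseOutputParser.parse_view_response
    # and automatically pick up the new (ai_text, data) signature.
    pass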