mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-13 22:15:35 +00:00
chat with plugin bug fix
This commit is contained in:
parent
bacc31658e
commit
6648a671df
@ -47,7 +47,7 @@ class BaseOutputParser(ABC):
|
||||
return code
|
||||
|
||||
# TODO 后续和模型绑定
|
||||
def _parse_model_stream_resp(self, response, sep: str):
|
||||
def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
|
||||
|
||||
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
||||
if chunk:
|
||||
@ -56,9 +56,8 @@ class BaseOutputParser(ABC):
|
||||
""" TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
|
||||
"""
|
||||
if data["error_code"] == 0:
|
||||
if "vicuna" in CFG.LLM_MODEL:
|
||||
|
||||
output = data["text"].strip()
|
||||
if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
|
||||
output = data["text"][skip_echo_len:].strip()
|
||||
else:
|
||||
output = data["text"].strip()
|
||||
|
||||
@ -97,7 +96,7 @@ class BaseOutputParser(ABC):
|
||||
else:
|
||||
raise ValueError("Model server error!code=" + respObj_ex["error_code"])
|
||||
|
||||
def parse_model_server_out(self, response):
|
||||
def parse_model_server_out(self, response, skip_echo_len: int = 0):
|
||||
"""
|
||||
parse the model server http response
|
||||
Args:
|
||||
@ -109,7 +108,7 @@ class BaseOutputParser(ABC):
|
||||
if not self.is_stream_out:
|
||||
return self._parse_model_nostream_resp(response, self.sep)
|
||||
else:
|
||||
return self._parse_model_stream_resp(response, self.sep)
|
||||
return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
|
||||
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
"""
|
||||
@ -143,7 +142,7 @@ class BaseOutputParser(ABC):
|
||||
cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
|
||||
return cleaned_output
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
def parse_view_response(self, ai_text, data) -> str:
|
||||
"""
|
||||
parse the ai response info to user view
|
||||
Args:
|
||||
|
@ -128,6 +128,8 @@ class BaseChat(ABC):
|
||||
|
||||
def stream_call(self):
|
||||
payload = self.__call_base()
|
||||
|
||||
skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
|
||||
logger.info(f"Requert: \n{payload}")
|
||||
ai_response_text = ""
|
||||
try:
|
||||
@ -139,7 +141,7 @@ class BaseChat(ABC):
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
|
||||
ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
|
||||
|
||||
for resp_text_trunck in ai_response_text:
|
||||
show_info = resp_text_trunck
|
||||
|
@ -15,7 +15,7 @@ class NormalChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
def parse_view_response(self, ai_text, data) -> str:
|
||||
return ai_text["table"]
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
|
@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
|
||||
def parse_prompt_response(self, model_out_text) -> T:
|
||||
return model_out_text
|
||||
|
||||
def parse_view_response(self, ai_text) -> str:
|
||||
return super().parse_view_response(ai_text)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user