Mirror of https://github.com/csunny/DB-GPT.git
Commit 6648a671df (parent bacc31658e): chat with plugin bug fix

This commit threads a new skip_echo_len argument from BaseChat.stream_call through parse_model_server_out into the streaming output parser, so that the prompt echoed back by the model server is stripped from streamed responses. It also extends the vicuna-specific handling to guanaco, and widens parse_view_response to accept an extra data argument, deleting a stale single-argument override.
```diff
@@ -47,7 +47,7 @@ class BaseOutputParser(ABC):
         return code
 
     # TODO: bind this to the model later
-    def _parse_model_stream_resp(self, response, sep: str):
+    def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
 
         for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
             if chunk:
```
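The body of the stream parser is not part of this hunk, so here is a minimal sketch of how the new skip_echo_len argument would be used inside it. The chunk framing (NUL-delimited JSON with error_code and text fields) is inferred from the surrounding hunks; the decoding details are assumptions.

```python
import json


def _parse_model_stream_resp(self, response, sep: str, skip_echo_len):
    """Sketch: yield partial output with the echoed prompt stripped."""
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] == 0:
                # The server echoes the prompt ahead of the completion;
                # the first skip_echo_len characters are that echo.
                yield data["text"][skip_echo_len:].strip()
            else:
                yield data["text"].strip()
```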
```diff
@@ -56,9 +56,8 @@ class BaseOutputParser(ABC):
                 """ TODO Multi mode output handler, rewrite this for multi model, use adapter mode.
                 """
                 if data["error_code"] == 0:
-                    if "vicuna" in CFG.LLM_MODEL:
-
-                        output = data["text"].strip()
+                    if CFG.LLM_MODEL in ["vicuna", "guanaco"]:
+                        output = data["text"][skip_echo_len:].strip()
                     else:
                         output = data["text"].strip()
 
```
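To make the slice concrete: vicuna- and guanaco-style servers return the prompt concatenated with the completion, so dropping the first skip_echo_len characters leaves only the newly generated text. A toy example with invented strings:

```python
# Invented values, purely illustrative.
prompt = "Hello, who are you?"
skip_echo_len = len(prompt) + 1                 # as computed in stream_call
text = prompt + " I am a database assistant."   # server echoes the prompt

print(text[skip_echo_len:].strip())  # -> "I am a database assistant."
```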
```diff
@@ -97,7 +96,7 @@ class BaseOutputParser(ABC):
         else:
             raise ValueError("Model server error!code=" + respObj_ex["error_code"])
 
-    def parse_model_server_out(self, response):
+    def parse_model_server_out(self, response, skip_echo_len: int = 0):
         """
         parse the model server http response
         Args:
@@ -109,7 +108,7 @@ class BaseOutputParser(ABC):
         if not self.is_stream_out:
             return self._parse_model_nostream_resp(response, self.sep)
         else:
-            return self._parse_model_stream_resp(response, self.sep)
+            return self._parse_model_stream_resp(response, self.sep, skip_echo_len)
 
     def parse_prompt_response(self, model_out_text) -> T:
         """
```
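Because skip_echo_len defaults to 0, existing non-stream callers of parse_model_server_out are unaffected; only the streaming path (stream_call, below) passes a real echo length.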
```diff
@@ -143,7 +142,7 @@ class BaseOutputParser(ABC):
         cleaned_output = cleaned_output.strip().replace('\n', '').replace('\\n', '').replace('\\', '').replace('\\', '')
         return cleaned_output
 
-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         """
         parse the ai response info to user view
         Args:
```
```diff
@@ -128,6 +128,8 @@ class BaseChat(ABC):
     def stream_call(self):
         payload = self.__call_base()
 
+        skip_echo_len = len(payload.get('prompt').replace("</s>", " ")) + 1
+
         logger.info(f"Requert: \n{payload}")
         ai_response_text = ""
         try:
```
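The echo length is measured on the prompt as the server will echo it: the "</s>" separators presumably come back as single spaces, hence the replace before taking the length, and the +1 covers the trailing separator. A standalone illustration with an invented prompt:

```python
# Invented payload; only the skip_echo_len line mirrors the commit.
payload = {"prompt": "USER: hi</s>ASSISTANT:"}

# Measure the prompt as it will appear in the echoed output
# ("</s>" presumably becomes a space), plus one separator character.
skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 1
print(skip_echo_len)  # -> 20 for this toy prompt
```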
```diff
@@ -139,7 +141,7 @@ class BaseChat(ABC):
             timeout=120,
         )
 
-        ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response)
+        ai_response_text = self.prompt_template.output_parser.parse_model_server_out(response, skip_echo_len)
 
         for resp_text_trunck in ai_response_text:
             show_info = resp_text_trunck
```
```diff
@@ -15,7 +15,7 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
+    def parse_view_response(self, ai_text, data) -> str:
         return ai_text["table"]
 
     def get_format_instructions(self) -> str:
```
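With the second parameter in place, this parser matches the new base-class signature; it renders a table and ignores data. A self-contained usage sketch: the shape of ai_text is inferred from return ai_text["table"], everything else is assumed.

```python
class NormalChatOutputParser:
    # Standalone stand-in for the parser in the diff.
    def parse_view_response(self, ai_text, data) -> str:
        return ai_text["table"]


parser = NormalChatOutputParser()
ai_text = {"table": "<table><tr><td>db-gpt</td></tr></table>"}  # assumed shape
print(parser.parse_view_response(ai_text, data=None))           # data unused here
```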
```diff
@@ -15,8 +15,5 @@ class NormalChatOutputParser(BaseOutputParser):
     def parse_prompt_response(self, model_out_text) -> T:
         return model_out_text
 
-    def parse_view_response(self, ai_text) -> str:
-        return super().parse_view_response(ai_text)
-
     def get_format_instructions(self) -> str:
         pass
```
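This last hunk (a second NormalChatOutputParser in a different module) deletes a stale pass-through override, which is plausibly the "chat with plugin" bug itself: once the base class grew the data parameter, an override still using the one-argument signature would raise a TypeError when the chat layer calls it with both arguments. A minimal reproduction with simplified stand-in classes:

```python
class Base:
    # Simplified stand-in for BaseOutputParser after this commit.
    def parse_view_response(self, ai_text, data) -> str:
        return ai_text


class StaleParser(Base):
    # The removed override: still the old one-argument signature.
    def parse_view_response(self, ai_text) -> str:
        return super().parse_view_response(ai_text)


try:
    StaleParser().parse_view_response("hi", {"rows": []})
except TypeError as err:
    print(err)  # parse_view_response() takes 2 positional arguments but 3 were given
```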