Files
DB-GPT/pilot/out_parser/base.py
yhjun1026 f19ee46e74 fix(ChatExcel): ChatExcel OutParse Bug Fix
1.ChatExcel OutParse Bug Fix
2023-09-14 20:49:06 +08:00

257 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union
from pilot.configs.config import Config
from pilot.configs.model_config import LOGDIR
from pilot.model.base import ModelOutput
from pilot.utils import build_logger
# Type variable used as the return annotation of
# BaseOutputParser.parse_prompt_response.
T = TypeVar("T")
# Raw model response types accepted by _parse_model_response: a JSON str,
# JSON bytes (NUL bytes are stripped there), or a ModelOutput instance
# (converted with dataclasses.asdict, so presumably a dataclass).
# NOTE: "Tye" looks like a typo of "Type"; the name is kept because other
# modules may import it.
ResponseTye = Union[str, bytes, ModelOutput]
# Module logger named "webserver"; the second argument is presumably the log
# file path (LOGDIR + "DbChatOutputParser.log") - confirm against build_logger.
logger = build_logger("webserver", LOGDIR + "DbChatOutputParser.log")
# Global configuration; CFG.LLM_MODEL selects the model-specific prompt-echo
# stripping in BaseOutputParser's stream parsers.
CFG = Config()
class BaseOutputParser(ABC):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    def __init__(self, sep: str, is_stream_out: bool = True):
        """Store the parser configuration.

        Args:
            sep: separator string, stored as ``self.sep``.
            is_stream_out: flag stored as ``self.is_stream_out``
                (presumably whether model output is streamed).
        """
        self.sep = sep
        self.is_stream_out = is_stream_out

    def __post_process_code(self, code):
        """Unescape ``\\_`` to ``_`` inside fenced code blocks of ``code``.

        The text is split on a newline followed by three backticks. Odd-indexed
        segments are fence interiors. The replacement is only applied when
        the split yields an odd number of segments, i.e. every fence that is
        opened is also closed; otherwise ``code`` is returned unchanged.
        """
        sep = "\n```"
        if sep in code:
            blocks = code.split(sep)
            if len(blocks) % 2 == 1:
                for i in range(1, len(blocks), 2):
                    blocks[i] = blocks[i].replace("\\_", "_")
                code = sep.join(blocks)
        return code

    def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
        """Extract the visible text from one model response chunk.

        TODO: multi-model output handling; rewrite this with an adapter per model.

        Args:
            chunk: raw chunk (``ModelOutput``, JSON ``str`` or JSON ``bytes``),
                decoded with ``_parse_model_response``.
            skip_echo_len: number of leading characters of ``text`` that echo
                the prompt. Overridden by ``model_context["prompt_echo_len_char"]``
                when that key is present and its value is not -1.

        Returns:
            On success (``error_code`` missing or 0): the text with the prompt
            echo removed (only for vicuna / llama-2 / guanaco models when echo
            is enabled), stripped, with code blocks post-processed.
            On error: the raw text followed by `` (error_code: N)``.
        """
        data = _parse_model_response(chunk)
        model_context = data.get("model_context")
        has_echo = True
        if model_context and "prompt_echo_len_char" in model_context:
            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
            has_echo = bool(model_context.get("echo", True))
            if prompt_echo_len_char != -1:
                skip_echo_len = prompt_echo_len_char

        if data.get("error_code", 0) != 0:
            return data["text"] + f" (error_code: {data['error_code']})"

        text = data["text"]
        model_name = CFG.LLM_MODEL
        if has_echo and ("vicuna" in model_name or "llama-2" in model_name):
            # TODO: decide the echo offset from model_context instead of the model name.
            output = text[skip_echo_len:].strip()
        elif has_echo and "guanaco" in model_name:
            # Streaming output: a fixed 11-character prefix is dropped and "<s>"
            # tokens removed (the non-stream variant used skip_echo_len + 2).
            output = text[11:].replace("<s>", "").strip()
        else:
            # TODO: gorilla and falcon output.
            output = text.strip()
        return self.__post_process_code(output)

    # TODO: bind this to the model later.
    def parse_model_stream_resp(self, response, skip_echo_len):
        """Yield parsed text for each NUL-delimited JSON chunk of ``response``.

        TODO: multi-model output handling; rewrite this with an adapter per model.

        Args:
            response: object providing ``iter_lines(decode_unicode=..., delimiter=...)``
                (presumably a streaming ``requests.Response``).
            skip_echo_len: number of leading characters of ``text`` to drop for
                vicuna / guanaco models (the echoed prompt).

        Yields:
            Stripped, code-post-processed text for successful chunks, or the raw
            text followed by `` (error_code: N)`` for failed ones.
        """
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if not chunk:
                continue
            data = json.loads(chunk.decode())
            if data["error_code"] != 0:
                yield data["text"] + f" (error_code: {data['error_code']})"
                continue
            if "vicuna" in CFG.LLM_MODEL or "guanaco" in CFG.LLM_MODEL:
                output = data["text"][skip_echo_len:].strip()
            else:
                output = data["text"].strip()
            yield self.__post_process_code(output)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Extract the assistant reply from a complete (non-streamed) response.

        Args:
            response: raw response (``ModelOutput``, JSON ``str`` or JSON ``bytes``).
            sep: separator used to split ``text`` into segments (presumably
                conversation turns).

        Returns:
            The last segment containing ``"assistant:"`` (or the final segment
            if none does). Role prefixes are removed, markdown escapes ``\\_`` and
            ``\\*`` are undone, tabs are dropped, and newlines (real and the
            literal two-character ``\\n``) become spaces.

        Raises:
            ValueError: if ``error_code`` is non-zero.
        """
        resp_obj_ex = _parse_model_response(response)
        if isinstance(resp_obj_ex, str):
            resp_obj_ex = json.loads(resp_obj_ex)
        if resp_obj_ex["error_code"] != 0:
            # error_code is compared against the int 0 elsewhere; concatenating
            # it to a str raised TypeError instead of this ValueError.
            raise ValueError(f"Model server error!code={resp_obj_ex['error_code']}")

        # Parse the returned text to get the AI reply part.
        segments = resp_obj_ex["text"].split(sep)
        last_index = -1
        for i, segment in enumerate(segments):
            if "assistant:" in segment:
                last_index = i
        ai_response = segments[last_index]
        for role_prefix in ("assistant:", "Assistant:", "ASSISTANT:"):
            ai_response = ai_response.replace(role_prefix, "")
        # Undo markdown escaping (backslash-underscore / backslash-star); drop tabs.
        ai_response = ai_response.replace("\\_", "_")
        ai_response = ai_response.replace("\\*", "*")
        ai_response = ai_response.replace("\t", "")
        ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
        logger.info("un_stream ai response: %s", ai_response)
        return ai_response

    def __illegal_json_ends(self, s):
        """Remove trailing commas before ``}`` / ``]`` in ``s``.

        ``", }"``/``",}"`` become ``" }"`` and ``", ]"``/``",]"`` become ``" ]"``,
        applied in that order.
        """
        temp_json = s
        for illegal_end, replacement in (
            (", }", " }"),
            (",}", " }"),
            (", ]", " ]"),
            (",]", " ]"),
        ):
            temp_json = temp_json.replace(illegal_end, replacement)
        return temp_json

    def __extract_json(self, s):
        """Return the longest leading bracketed JSON candidate found in ``s``.

        Both the first balanced ``{...}`` span and the first balanced ``[...]``
        span are located. The object span wins only if strictly longer. The
        result has trailing commas repaired and is ``""`` when neither span
        exists.
        """
        temp_json_simple = self.__json_interception(s)
        temp_json_array = self.__json_interception(s, True)
        if len(temp_json_simple) > len(temp_json_array):
            temp_json = temp_json_simple
        else:
            temp_json = temp_json_array
        return self.__illegal_json_ends(temp_json)

    def __json_interception(self, s, is_json_array: bool = False):
        """Return the first balanced ``{...}`` (or ``[...]``) span in ``s``.

        Args:
            s: text to search.
            is_json_array: search for ``[``/``]`` instead of ``{``/``}``.

        Returns:
            The span from the first opening bracket through its matching
            closing bracket. Returns ``""`` when ``s`` is not a ``str``, has no
            opening bracket, or the bracket is never closed.
            Note: brackets inside JSON string values are counted too.
        """
        if not isinstance(s, str):
            return ""
        opener, closer = ("[", "]") if is_json_array else ("{", "}")
        start = s.find(opener)
        if start < 0:
            return ""
        depth = 0
        for end in range(start, len(s)):
            ch = s[end]
            if ch == opener:
                depth += 1
            elif ch == closer:
                depth -= 1
                if depth == 0:
                    return s[start : end + 1]
        # Unbalanced: an explicit check (the old assert vanished under -O).
        return ""

    def parse_prompt_response(self, model_out_text) -> T:
        """Clean raw model output into JSON text for the prompt's response.

        Keeps the body of the first ```json fence (if any) and strips
        surrounding fences and whitespace. Falls back to bracket extraction
        when the result is not a ``{...}`` object. Then replaces newlines
        (real and literal ``\\n``) and backslashes with spaces and repairs
        trailing commas.

        Args:
            model_out_text: raw text returned by the model.

        Returns:
            The cleaned text; may be empty when no bracketed span is found.
        """
        cleaned_output = model_out_text.rstrip()
        if "```json" in cleaned_output:
            # Keep the text after the first ```json marker (up to any later one).
            # Two-way unpacking raised ValueError on multiple ```json fences.
            cleaned_output = cleaned_output.split("```json")[1]
        if cleaned_output.startswith("```"):
            cleaned_output = cleaned_output[len("```") :]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]
        cleaned_output = cleaned_output.strip()
        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
            logger.info("illegal json processing:\n%s", cleaned_output)
            cleaned_output = self.__extract_json(cleaned_output)
        cleaned_output = (
            cleaned_output.strip()
            .replace("\\n", " ")
            .replace("\n", " ")
            .replace("\\", " ")
        )
        return self.__illegal_json_ends(cleaned_output)

    def parse_view_response(self, ai_text, data) -> str:
        """Convert the AI response into the text shown to the user.

        The base implementation returns ``ai_text`` unchanged and ignores ``data``.

        Args:
            ai_text: parsed AI response text.
            data: extra data available to subclasses.
        """
        return ai_text

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary representation of the output parser.

        Contains a copy of the instance attributes plus a ``"_type"`` key, taken
        from a ``_type`` attribute/property if a subclass defines one, else the
        class name. The former ``super().dict()`` call always raised
        ``AttributeError`` because this class derives from ``ABC``.
        ``kwargs`` is accepted for backward compatibility and ignored.
        """
        output_parser_dict = vars(self).copy()
        output_parser_dict["_type"] = getattr(self, "_type", type(self).__name__)
        return output_parser_dict
def _parse_model_response(response: ResponseTye):
    """Decode a raw model response; callers in this module use the result as a dict.

    Args:
        response: a ``ModelOutput`` instance (converted with
            ``dataclasses.asdict``), a JSON ``str``, or JSON ``bytes``. NUL
            bytes are removed from ``bytes`` input before it is decoded as UTF-8.

    Returns:
        The ``asdict`` result or the ``json.loads`` result.

    Raises:
        ValueError: if ``response`` is none of the supported types (JSON
            decoding errors, themselves ``ValueError`` subclasses, also propagate).
    """
    if isinstance(response, ModelOutput):
        return asdict(response)
    if isinstance(response, str):
        return json.loads(response)
    if not isinstance(response, bytes):
        raise ValueError(f"Unsupported response type {type(response)}")
    # Drop NUL bytes first; the stream reader in this module splits chunks on b"\0".
    payload = response.replace(b"\0", b"")
    return json.loads(payload.decode())