refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
Author: FangYin Cheng
Committed by: GitHub
Date: 2023-12-08 14:45:59 +08:00
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions


@@ -0,0 +1,253 @@
from __future__ import annotations

import json
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, Dict, TypeVar, Union

from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator

T = TypeVar("T")

ResponseType = Union[str, bytes, ModelOutput]

logger = logging.getLogger(__name__)

class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    def __init__(self, is_stream_out: bool = True, **kwargs):
        super().__init__(**kwargs)
        self.is_stream_out = is_stream_out
        self.data_schema = None

    def update(self, data_schema):
        """Update the data schema used when parsing structured output."""
        self.data_schema = data_schema

    def __post_process_code(self, code):
        """Unescape underscores inside the fenced code blocks of a reply."""
        sep = "\n```"
        if sep in code:
            blocks = code.split(sep)
            if len(blocks) % 2 == 1:
                # An odd block count means the fences are balanced; the
                # odd-indexed blocks are the code segments.
                for i in range(1, len(blocks), 2):
                    blocks[i] = blocks[i].replace("\\_", "_")
            code = sep.join(blocks)
        return code
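
    # Sketch of what __post_process_code repairs (the sample text is
    # hypothetical): "```\nmy\_var = 1\n```" becomes "```\nmy_var = 1\n```",
    # i.e. escaped underscores inside balanced fences are unescaped while
    # prose outside the fences is left untouched.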

    def parse_model_stream_resp_ex(self, chunk: ResponseType, skip_echo_len):
        """Parse a single streaming chunk from the model server.

        TODO: rewrite this as a multi-model output handler using the
        adapter pattern.
        """
        data = _parse_model_response(chunk)
        model_context = data.get("model_context")
        has_echo = True
        if model_context and "prompt_echo_len_char" in model_context:
            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
            has_echo = bool(model_context.get("echo", False))
            if prompt_echo_len_char != -1:
                skip_echo_len = prompt_echo_len_char

        if data.get("error_code", 0) == 0:
            if has_echo:
                # TODO: judge this from model_context instead
                output = data["text"][skip_echo_len:].strip()
            else:
                output = data["text"].strip()
            output = self.__post_process_code(output)
            return output
        else:
            output = data["text"] + f" (error_code: {data['error_code']})"
            return output
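
    # Echo handling above, illustrated with hypothetical values: for
    # model_context = {"prompt_echo_len_char": 12, "echo": True}, the first
    # 12 characters of data["text"] (the echoed prompt) are dropped before
    # code post-processing.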

    def parse_model_nostream_resp(self, response: ResponseType, sep: str):
        """Parse a complete (non-streaming) model response."""
        resp_obj_ex = _parse_model_response(response)
        if isinstance(resp_obj_ex, str):
            resp_obj_ex = json.loads(resp_obj_ex)
        if resp_obj_ex["error_code"] == 0:
            all_text = resp_obj_ex["text"]
            # Split the returned text on the separator and keep the AI reply.
            tmp_resp = all_text.split(sep)
            last_index = -1
            for i in range(len(tmp_resp)):
                if tmp_resp[i].find("assistant:") != -1:
                    last_index = i
            ai_response = tmp_resp[last_index]
            ai_response = ai_response.replace("assistant:", "")
            ai_response = ai_response.replace("Assistant:", "")
            ai_response = ai_response.replace("ASSISTANT:", "")
            ai_response = ai_response.replace("\\_", "_")
            ai_response = ai_response.replace("\\*", "*")
            ai_response = ai_response.replace("\t", "")
            ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
            logger.info("un_stream ai response: %s", ai_response)
            return ai_response
        else:
            raise ValueError(
                f"Model server error! code={resp_obj_ex['error_code']}, "
                f"errmsg is {resp_obj_ex['text']}"
            )
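
    # Separator splitting above, with hypothetical input and sep="###":
    #   "system: ...###human: hi###assistant: hello"
    # keeps the last segment containing "assistant:" and strips the role
    # prefix and escape artifacts, yielding "hello".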

    def __illegal_json_ends(self, s):
        """Repair trailing commas before closing braces and brackets."""
        temp_json = s
        illegal_json_ends_1 = [", }", ",}"]
        illegal_json_ends_2 = [", ]", ",]"]
        for illegal_json_end in illegal_json_ends_1:
            temp_json = temp_json.replace(illegal_json_end, " }")
        for illegal_json_end in illegal_json_ends_2:
            temp_json = temp_json.replace(illegal_json_end, " ]")
        return temp_json
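
    # Repair example (hypothetical input):
    #   '{"items": [1, 2,], }'  ->  '{"items": [1, 2 ] }'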

    def __extract_json(self, s):
        try:
            # Run both extraction modes and keep the longer result.
            temp_json_simple = self.__json_interception(s)
            temp_json_array = self.__json_interception(s, True)
            if len(temp_json_simple) > len(temp_json_array):
                temp_json = temp_json_simple
            else:
                temp_json = temp_json_array
            if not temp_json:
                temp_json = self.__json_interception(s)
            temp_json = self.__illegal_json_ends(temp_json)
            return temp_json
        except Exception as e:
            raise ValueError(
                "Failed to find a valid json in LLM response: " + s
            ) from e

    def __json_interception(self, s, is_json_array: bool = False):
        """Extract the first balanced JSON object (or array) from s."""
        try:
            if is_json_array:
                i = s.find("[")
                if i < 0:
                    return ""
                count = 1
                for j, c in enumerate(s[i + 1 :], start=i + 1):
                    if c == "]":
                        count -= 1
                    elif c == "[":
                        count += 1
                    if count == 0:
                        break
                assert count == 0
                return s[i : j + 1]
            else:
                i = s.find("{")
                if i < 0:
                    return ""
                count = 1
                for j, c in enumerate(s[i + 1 :], start=i + 1):
                    if c == "}":
                        count -= 1
                    elif c == "{":
                        count += 1
                    if count == 0:
                        break
                assert count == 0
                return s[i : j + 1]
        except Exception:
            return ""

    def parse_prompt_response(self, model_out_text) -> T:
        """Parse model output text into the prompt-defined response.

        Args:
            model_out_text: The raw text returned by the model.

        Returns:
            The extracted JSON string, or the original text if no JSON
            could be recovered.
        """
        cleaned_output = model_out_text.rstrip()
        if "```json" in cleaned_output:
            _, cleaned_output = cleaned_output.split("```json", 1)
        # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
        if cleaned_output.startswith("```json"):
            cleaned_output = cleaned_output[len("```json") :]
        if cleaned_output.startswith("```"):
            cleaned_output = cleaned_output[len("```") :]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]
        cleaned_output = cleaned_output.strip()
        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
            logger.info("illegal json processing:\n" + cleaned_output)
            cleaned_output = self.__extract_json(cleaned_output)
        if not cleaned_output:
            return model_out_text
        cleaned_output = (
            cleaned_output.strip()
            .replace("\\n", " ")
            .replace("\n", " ")
            .replace("\\", " ")
            .replace("\\_", "_")
        )
        cleaned_output = self.__illegal_json_ends(cleaned_output)
        return cleaned_output
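
    # End-to-end example (hypothetical model output):
    #   'Sure! ```json\n{"intent": "query",}\n```'
    # The fence and surrounding chatter are stripped, the trailing comma is
    # repaired by __illegal_json_ends, and '{"intent": "query" }' is returned.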

    def parse_view_response(
        self, ai_text, data, parse_prompt_response: Any = None
    ) -> str:
        """Convert the AI response into the user-facing view.

        Args:
            ai_text: The parsed AI response text.
            data: Extra data associated with the response.
            parse_prompt_response: Optional pre-parsed prompt response.

        Returns:
            The text to present to the user (the default passes ai_text
            through unchanged).
        """
        return ai_text

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    # @property
    # def _type(self) -> str:
    #     """Return the type key."""
    #     raise NotImplementedError(
    #         f"_type property is not implemented in class {self.__class__.__name__}."
    #         " This is required for serialization."
    #     )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser."""
        # Note: relies on subclasses defining a `_type` property (see the
        # commented-out stub above).
        output_parser_dict = super().dict()
        output_parser_dict["_type"] = self._type
        return output_parser_dict
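
    # Hypothetical subclass sketch showing the `_type` contract dict()
    # expects (names are illustrative, not from the source):
    #
    #   class MyJsonOutputParser(BaseOutputParser):
    #       @property
    #       def _type(self) -> str:
    #           return "my_json"
    #
    #       def get_format_instructions(self) -> str:
    #           return "Respond with a single JSON object."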

    async def map(self, input_value: ModelOutput) -> Any:
        """Parse the output of an LLM call.

        Args:
            input_value (ModelOutput): The output of an LLM call.

        Returns:
            Any: The parsed output of an LLM call.
        """
        if self.current_dag_context.streaming_call:
            return self.parse_model_stream_resp_ex(input_value, 0)
        else:
            return self.parse_model_nostream_resp(input_value, "###")


def _parse_model_response(response: ResponseType):
    """Normalize a model response (str, bytes, or ModelOutput) into a dict."""
    if response is None:
        resp_obj_ex = ""
    elif isinstance(response, ModelOutput):
        resp_obj_ex = asdict(response)
    elif isinstance(response, str):
        resp_obj_ex = json.loads(response)
    elif isinstance(response, bytes):
        # Strip NUL bytes some servers append before decoding.
        if b"\0" in response:
            response = response.replace(b"\0", b"")
        resp_obj_ex = json.loads(response.decode())
    else:
        raise ValueError(f"Unsupported response type {type(response)}")
    return resp_obj_ex
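
A minimal usage sketch of the non-streaming path, assuming ModelOutput exposes `text` and `error_code` fields (as the asdict-based handling above implies) and that the operator can be constructed outside a running DAG:

    # Hypothetical server payload: prompt echo plus the assistant reply,
    # joined by the "###" separator that map() uses for non-streaming calls.
    raw = ModelOutput(
        text="human: What is 2 + 2?###assistant: 2 + 2 = 4",
        error_code=0,
    )
    parser = BaseOutputParser(is_stream_out=False)
    print(parser.parse_model_nostream_resp(raw, "###"))  # -> "2 + 2 = 4"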