refactor: Refactor storage system (#937)

This commit is contained in:
Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
)
def __illegal_json_ends(self, s):
def _illegal_json_ends(self, s):
temp_json = s
illegal_json_ends_1 = [", }", ",}"]
illegal_json_ends_2 = ", ]", ",]"
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
temp_json = temp_json.replace(illegal_json_end, " ]")
return temp_json
def __extract_json(self, s):
def _extract_json(self, s):
try:
# Get the dual-mode analysis first and get the maximum result
temp_json_simple = self.__json_interception(s)
temp_json_array = self.__json_interception(s, True)
temp_json_simple = self._json_interception(s)
temp_json_array = self._json_interception(s, True)
if len(temp_json_simple) > len(temp_json_array):
temp_json = temp_json_simple
else:
temp_json = temp_json_array
if not temp_json:
temp_json = self.__json_interception(s)
temp_json = self._json_interception(s)
temp_json = self.__illegal_json_ends(temp_json)
temp_json = self._illegal_json_ends(temp_json)
return temp_json
except Exception as e:
raise ValueError("Failed to find a valid json in LLM response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False):
def _json_interception(self, s, is_json_array: bool = False):
try:
if is_json_array:
i = s.find("[")
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing:\n" + cleaned_output)
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = self._extract_json(cleaned_output)
if not cleaned_output or len(cleaned_output) <= 0:
return model_out_text
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
.replace("\\", " ")
.replace("\_", "_")
)
cleaned_output = self.__illegal_json_ends(cleaned_output)
cleaned_output = self._illegal_json_ends(cleaned_output)
return cleaned_output
def parse_view_response(
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
# @property
# def _type(self) -> str:
# """Return the type key."""
# raise NotImplementedError(
# f"_type property is not implemented in class {self.__class__.__name__}."
# " This is required for serialization."
# )
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
async def map(self, input_value: ModelOutput) -> Any:
"""Parse the output of an LLM call.