mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 10:20:01 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
|
||||
)
|
||||
|
||||
def __illegal_json_ends(self, s):
|
||||
def _illegal_json_ends(self, s):
|
||||
temp_json = s
|
||||
illegal_json_ends_1 = [", }", ",}"]
|
||||
illegal_json_ends_2 = ", ]", ",]"
|
||||
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
temp_json = temp_json.replace(illegal_json_end, " ]")
|
||||
return temp_json
|
||||
|
||||
def __extract_json(self, s):
|
||||
def _extract_json(self, s):
|
||||
try:
|
||||
# Get the dual-mode analysis first and get the maximum result
|
||||
temp_json_simple = self.__json_interception(s)
|
||||
temp_json_array = self.__json_interception(s, True)
|
||||
temp_json_simple = self._json_interception(s)
|
||||
temp_json_array = self._json_interception(s, True)
|
||||
if len(temp_json_simple) > len(temp_json_array):
|
||||
temp_json = temp_json_simple
|
||||
else:
|
||||
temp_json = temp_json_array
|
||||
|
||||
if not temp_json:
|
||||
temp_json = self.__json_interception(s)
|
||||
temp_json = self._json_interception(s)
|
||||
|
||||
temp_json = self.__illegal_json_ends(temp_json)
|
||||
temp_json = self._illegal_json_ends(temp_json)
|
||||
return temp_json
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to find a valid json in LLM response!" + temp_json)
|
||||
|
||||
def __json_interception(self, s, is_json_array: bool = False):
|
||||
def _json_interception(self, s, is_json_array: bool = False):
|
||||
try:
|
||||
if is_json_array:
|
||||
i = s.find("[")
|
||||
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
cleaned_output = cleaned_output.strip()
|
||||
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
|
||||
logger.info("illegal json processing:\n" + cleaned_output)
|
||||
cleaned_output = self.__extract_json(cleaned_output)
|
||||
cleaned_output = self._extract_json(cleaned_output)
|
||||
|
||||
if not cleaned_output or len(cleaned_output) <= 0:
|
||||
return model_out_text
|
||||
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
.replace("\\", " ")
|
||||
.replace("\_", "_")
|
||||
)
|
||||
cleaned_output = self.__illegal_json_ends(cleaned_output)
|
||||
cleaned_output = self._illegal_json_ends(cleaned_output)
|
||||
return cleaned_output
|
||||
|
||||
def parse_view_response(
|
||||
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @property
|
||||
# def _type(self) -> str:
|
||||
# """Return the type key."""
|
||||
# raise NotImplementedError(
|
||||
# f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
# " This is required for serialization."
|
||||
# )
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict()
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> Any:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
|
Reference in New Issue
Block a user