DB-GPT/dbgpt/core/interface/output_parser.py
2024-08-29 19:39:42 +08:00

390 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""The output parser is used to parse the output of an LLM call.
TODO: Make this more general and clear.
"""
from __future__ import annotations
import json
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, TypeVar, Union
from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
ViewMetadata,
)
from dbgpt.util.i18n_utils import _
# Generic type variable; no definition visible in this part of the file uses it.
T = TypeVar("T")
# Raw model response accepted by the parsers below: a JSON string, JSON bytes,
# or a ModelOutput instance (see `_parse_model_response`).
# NOTE(review): the name looks like a typo of "ResponseType"; it is left as-is
# because other modules may import it under this spelling - confirm before
# renaming.
ResponseTye = Union[str, bytes, ModelOutput]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses. This base class
    offers helpers to strip prompt echoes, pull the assistant reply out of a
    raw completion, and recover a JSON-looking payload from free-form text.
    """

    metadata = ViewMetadata(
        label=_("Base Output Operator"),
        name="base_output_operator",
        operator_type=OperatorType.TRANSFORM_STREAM,
        category=OperatorCategory.OUTPUT_PARSER,
        description=_("The base LLM out parse."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                is_list=True,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                str,
                is_list=True,
                description=_("The model output after parsing."),
            )
        ],
    )

    def __init__(self, is_stream_out: bool = True, **kwargs):
        """Create a new output parser.

        Args:
            is_stream_out (bool): Stored on the instance as ``is_stream_out``.
            **kwargs: Forwarded unchanged to the parent operator constructor.
        """
        super().__init__(**kwargs)
        self.is_stream_out = is_stream_out
        self.data_schema = None

    def update(self, data_schema):
        """Update the data schema.

        TODO: Remove this method.
        """
        self.data_schema = data_schema

    def __post_process_code(self, code):
        """Remove backslash-escaped underscores inside fenced code blocks.

        The text is split on a newline followed by three backticks. Only when
        the number of pieces is odd (every fence that was opened is also
        closed) are the odd-indexed pieces, i.e. the fenced bodies, rewritten.

        Args:
            code (str): Text that may contain fenced code blocks.

        Returns:
            str: The text with escaped underscores restored inside fences.
        """
        sep = "\n```"
        if sep in code:
            blocks = code.split(sep)
            if len(blocks) % 2 == 1:
                for i in range(1, len(blocks), 2):
                    blocks[i] = blocks[i].replace("\\_", "_")
                code = sep.join(blocks)
        return code

    def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
        """Parse one streamed chunk of an LLM call.

        Args:
            chunk (ResponseTye): The output of an LLM call.
            skip_echo_len (int): The length of the prompt to skip. It is
                overridden by ``model_context["prompt_echo_len_char"]`` when
                that value is present and not ``-1``.

        Returns:
            str: The stripped text (with the prompt echo removed when the
                model context reports ``echo``), or the raw text followed by
                the error code when ``error_code`` is non-zero.
        """
        data = _parse_model_response(chunk)
        # TODO: Multi mode output handler, rewrite this for multi model, use adapter
        #  mode.
        model_context = data.get("model_context")
        has_echo = False
        if model_context and "prompt_echo_len_char" in model_context:
            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
            has_echo = bool(model_context.get("echo", False))
            if prompt_echo_len_char != -1:
                skip_echo_len = prompt_echo_len_char

        if data.get("error_code", 0) != 0:
            return data["text"] + f" (error_code: {data['error_code']})"

        text = data["text"]
        if has_echo:
            # TODO Judging from model_context
            text = text[skip_echo_len:]
        return self.__post_process_code(text.strip())

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Parse the complete (non-streaming) output of an LLM call.

        The text is split on ``sep`` and the last segment containing the
        lowercase marker ``assistant:`` is selected (the final segment when no
        segment contains it). Role markers, escaped underscores/asterisks and
        tab characters are then removed from that segment.

        Args:
            response (ResponseTye): The output of an LLM call.
            sep (str): Separator between conversation segments.

        Returns:
            str: The assistant reply text.

        Raises:
            ValueError: If the response carries a non-zero ``error_code``.
        """
        resp_obj_ex = _parse_model_response(response)
        if isinstance(resp_obj_ex, str):
            resp_obj_ex = json.loads(resp_obj_ex)
        if resp_obj_ex["error_code"] != 0:
            raise ValueError(
                f"Model server error!code={resp_obj_ex['error_code']}, error msg is "
                f"{resp_obj_ex['text']}"
            )

        # Parse the returned text to get the AI reply part
        segments = resp_obj_ex["text"].split(sep)
        last_index = -1
        for index, segment in enumerate(segments):
            if "assistant:" in segment:
                last_index = index

        ai_response = segments[last_index]
        for marker in ("assistant:", "Assistant:", "ASSISTANT:"):
            ai_response = ai_response.replace(marker, "")
        ai_response = (
            ai_response.replace("\\_", "_").replace("\\*", "*").replace("\t", "")
        )
        # Use the module logger instead of writing debug output to stdout.
        logger.debug("un_stream ai response: %s", ai_response)
        return ai_response

    def _illegal_json_ends(self, s):
        """Drop trailing commas placed right before a closing brace/bracket.

        This is a plain textual replacement, so matching sequences inside
        string values are rewritten as well.

        Args:
            s (str): JSON-like text.

        Returns:
            str: The text with the trailing-comma closers replaced.
        """
        temp_json = s
        replacements = (
            (", }", " }"),
            (",}", " }"),
            (", ]", " ]"),
            (",]", " ]"),
        )
        for illegal_end, legal_end in replacements:
            temp_json = temp_json.replace(illegal_end, legal_end)
        return temp_json

    def _extract_json(self, s):
        """Extract the longest balanced object/array candidate from ``s``.

        Both an object (``{...}``) and an array (``[...]``) interception are
        attempted; the longer result wins and a tie goes to the array.

        Args:
            s (str): Text that may embed a JSON payload.

        Returns:
            str: The candidate with trailing commas fixed, or an empty string
                when neither interception finds a balanced span.

        Raises:
            ValueError: If the extraction fails unexpectedly.
        """
        try:
            # Get the dual-mode analysis first and get the maximum result
            as_object = self._json_interception(s)
            as_array = self._json_interception(s, True)
            candidate = as_object if len(as_object) > len(as_array) else as_array
            return self._illegal_json_ends(candidate)
        except Exception as err:
            # Report the input itself: no intermediate value is guaranteed to
            # exist at the point of failure.
            raise ValueError(
                f"Failed to find a valid json in LLM response: {s}"
            ) from err

    def _json_interception(self, s, is_json_array: bool = False):
        """Return the first balanced ``{...}`` (or ``[...]``) span of ``s``.

        Only the chosen bracket pair is counted; bracket characters that occur
        inside string literals are counted too.

        Args:
            s (str): Text to scan.
            is_json_array (bool): Scan for ``[...]`` instead of ``{...}``.

        Returns:
            str: The balanced span including both delimiters, or an empty
                string when there is no opening delimiter, the span is never
                closed, or ``s`` cannot be scanned.
        """
        opener, closer = ("[", "]") if is_json_array else ("{", "}")
        try:
            start = s.find(opener)
            if start < 0:
                return ""
            depth = 0
            for pos in range(start, len(s)):
                char = s[pos]
                if char == opener:
                    depth += 1
                elif char == closer:
                    depth -= 1
                    if depth == 0:
                        return s[start : pos + 1]
            # Unbalanced input: reject explicitly rather than relying on an
            # ``assert`` that disappears under ``python -O``.
            return ""
        except Exception:
            # Best effort by design: input that cannot be scanned yields "".
            return ""

    def parse_prompt_response(self, model_out_text) -> Any:
        """Parse model out text to prompt define response.

        Args:
            model_out_text: The output of an LLM call.

        Returns:
            Any: The cleaned JSON-looking text, or ``model_out_text`` itself
                when no JSON payload could be extracted.
        """
        fence = "```"
        json_fence = "```json"

        cleaned_output = model_out_text.rstrip()
        if json_fence in cleaned_output:
            # Keep everything after the first marker; maxsplit=1 avoids an
            # unpacking error when the marker appears more than once.
            cleaned_output = cleaned_output.split(json_fence, 1)[1]
        if cleaned_output.startswith(json_fence):
            cleaned_output = cleaned_output[len(json_fence) :]
        if cleaned_output.startswith(fence):
            cleaned_output = cleaned_output[len(fence) :]
        if cleaned_output.endswith(fence):
            cleaned_output = cleaned_output[: -len(fence)]
        cleaned_output = cleaned_output.strip()

        if not (cleaned_output.startswith("{") and cleaned_output.endswith("}")):
            logger.info("illegal json processing:\n%s", cleaned_output)
            cleaned_output = self._extract_json(cleaned_output)

        if not cleaned_output:
            return model_out_text

        cleaned_output = (
            cleaned_output.strip()
            .replace("\\n", " ")
            .replace("\n", " ")
            # Un-escape underscores BEFORE blanking the remaining backslashes,
            # otherwise this replacement can never match.
            .replace("\\_", "_")
            .replace("\\", " ")
        )
        return self._illegal_json_ends(cleaned_output)

    def parse_view_response(
        self, ai_text, data, parse_prompt_response: Any = None
    ) -> str:
        """Parse the AI response info to user view.

        The base implementation returns ``ai_text`` unchanged.

        Args:
            ai_text (str): The output of an LLM call.
            data (dict): The data has been handled by some scene.
            parse_prompt_response (Any): The prompt response has been parsed.

        Returns:
            str: The parsed output of an LLM call.
        """
        return ai_text

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    async def map(self, input_value: ModelOutput) -> Any:
        """Parse the output of an LLM call.

        Dispatches to the streaming or non-streaming parser depending on
        ``self.current_dag_context.streaming_call``.

        Args:
            input_value (ModelOutput): The output of an LLM call.

        Returns:
            Any: The parsed output of an LLM call.
        """
        if self.current_dag_context.streaming_call:
            return self.parse_model_stream_resp_ex(input_value, 0)
        else:
            return self.parse_model_nostream_resp(input_value, "#####################")
def _parse_model_response(response: ResponseTye):
if response is None:
resp_obj_ex = ""
elif isinstance(response, ModelOutput):
resp_obj_ex = asdict(response)
elif isinstance(response, str):
resp_obj_ex = json.loads(response)
elif isinstance(response, bytes):
if b"\0" in response:
response = response.replace(b"\0", b"")
resp_obj_ex = json.loads(response.decode())
else:
raise ValueError(f"Unsupported response type {type(response)}")
return resp_obj_ex
class SQLOutputParser(BaseOutputParser):
    """Output parser that turns the model reply into a JSON-decoded object."""

    metadata = ViewMetadata(
        label=_("SQL Output Parser"),
        name="default_sql_output_parser",
        category=OperatorCategory.OUTPUT_PARSER,
        description=_("Parse the SQL output of an LLM call."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Dict SQL Output"),
                "dict",
                dict,
                description=_("The dict output after parsing."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, is_stream_out: bool = False, **kwargs):
        """Build the parser; ``is_stream_out`` defaults to ``False`` here."""
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Extract the assistant reply and decode it as JSON.

        The reply is obtained from the base class, cleaned with the base
        class's ``parse_prompt_response`` and passed to ``json.loads``.

        Args:
            response (ResponseTye): The output of an LLM call.
            sep (str): Separator between conversation segments.

        Returns:
            The value decoded by ``json.loads`` from the cleaned reply.
        """
        reply_text = super().parse_model_nostream_resp(response, sep)
        json_text = super().parse_prompt_response(reply_text)
        return json.loads(json_text, strict=True)
class SQLListOutputParser(BaseOutputParser):
    """Output parser that returns the model reply as a list of JSON values."""

    metadata = ViewMetadata(
        label=_("SQL List Output Parser"),
        name="default_sql_list_output_parser",
        category=OperatorCategory.OUTPUT_PARSER,
        description=_(
            "Parse the SQL list output of an LLM call, mostly used for dashboard."
        ),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("List SQL Output"),
                "list",
                dict,
                is_list=True,
                description=_("The list output after parsing."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, is_stream_out: bool = False, **kwargs):
        """Build the parser; ``is_stream_out`` defaults to ``False`` here."""
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Extract the first JSON value of the reply and return it as a list.

        Args:
            response (ResponseTye): The output of an LLM call.
            sep (str): Separator between conversation segments.

        Returns:
            list: The first JSON value found when it is a list, or a
                one-element list wrapping it when it is a dict.

        Raises:
            ValueError: If no JSON value is found, or the first one is neither
                a list nor a dict.
        """
        from dbgpt.util.json_utils import find_json_objects

        reply_text = super().parse_model_nostream_resp(response, sep)
        found = find_json_objects(reply_text)
        if len(found) < 1:
            raise ValueError("Unable to obtain valid output.")

        first = found[0]
        if isinstance(first, list):
            return first
        if isinstance(first, dict):
            return [first]
        raise ValueError("Invalid output format.")