"""The output parser is used to parse the output of an LLM call.
|
||
|
||
TODO: Make this more general and clear.
|
||
"""
|
||
|
||

from __future__ import annotations

import json
import logging
from abc import ABC
from dataclasses import asdict
from typing import Any, TypeVar, Union

from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    OperatorType,
    ViewMetadata,
)
from dbgpt.util.i18n_utils import _

T = TypeVar("T")
ResponseTye = Union[str, bytes, ModelOutput]

logger = logging.getLogger(__name__)


class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
    """Class to parse the output of an LLM call.

    Output parsers help structure language model responses.
    """

    metadata = ViewMetadata(
        label=_("Base Output Operator"),
        name="base_output_operator",
        operator_type=OperatorType.TRANSFORM_STREAM,
        category=OperatorCategory.OUTPUT_PARSER,
        description=_("The base LLM output parser."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                is_list=True,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                str,
                is_list=True,
                description=_("The model output after parsing."),
            )
        ],
    )

    def __init__(self, is_stream_out: bool = True, **kwargs):
        """Create a new output parser."""
        super().__init__(**kwargs)
        self.is_stream_out = is_stream_out
        self.data_schema = None

    def update(self, data_schema):
        """Update the data schema.

        TODO: Remove this method.
        """
        self.data_schema = data_schema

    def __post_process_code(self, code):
        """Unescape underscores inside fenced code blocks.

        Some models emit "\\_" instead of "_"; restore it, but only inside
        balanced ``` fenced blocks.
        """
        sep = "\n```"
        if sep in code:
            blocks = code.split(sep)
            if len(blocks) % 2 == 1:
                # An odd number of blocks means the fences are balanced; the
                # odd-indexed blocks are the fenced (code) sections.
                for i in range(1, len(blocks), 2):
                    blocks[i] = blocks[i].replace("\\_", "_")
            code = sep.join(blocks)
        return code
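
    # A worked example of __post_process_code (illustrative comment only):
    # splitting 'a\n```\nfoo\\_bar\n```\nb' on "\n```" yields
    # ["a", "\nfoo\\_bar", "\nb"] -- three blocks, so the fences are balanced
    # and block 1 (the fenced code) is rewritten to "\nfoo_bar".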

    def parse_model_stream_resp_ex(self, chunk: ResponseTye, skip_echo_len):
        """Parse the output of a streaming LLM call.

        Args:
            chunk (ResponseTye): The output of an LLM call.
            skip_echo_len (int): The length of the prompt to skip.
        """
        data = _parse_model_response(chunk)
        # TODO: Rewrite this as a multi-model output handler using the adapter
        #  mode.

        model_context = data.get("model_context")
        has_echo = False
        if model_context and "prompt_echo_len_char" in model_context:
            prompt_echo_len_char = int(model_context.get("prompt_echo_len_char", -1))
            has_echo = bool(model_context.get("echo", False))
            if prompt_echo_len_char != -1:
                skip_echo_len = prompt_echo_len_char

        if data.get("error_code", 0) == 0:
            if has_echo:
                # TODO: Judge this from model_context instead.
                output = data["text"][skip_echo_len:].strip()
            else:
                output = data["text"].strip()

            output = self.__post_process_code(output)
            return output
        else:
            output = data["text"] + f" (error_code: {data['error_code']})"
            return output

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Parse the output of a non-streaming LLM call."""
        resp_obj_ex = _parse_model_response(response)
        if isinstance(resp_obj_ex, str):
            resp_obj_ex = json.loads(resp_obj_ex)
        if resp_obj_ex["error_code"] == 0:
            all_text = resp_obj_ex["text"]
            # Split the returned text on the separator and keep the last
            # segment that contains an assistant reply.
            tmp_resp = all_text.split(sep)
            last_index = -1
            for i in range(len(tmp_resp)):
                if tmp_resp[i].find("assistant:") != -1:
                    last_index = i
            ai_response = tmp_resp[last_index]
            ai_response = ai_response.replace("assistant:", "")
            ai_response = ai_response.replace("Assistant:", "")
            ai_response = ai_response.replace("ASSISTANT:", "")
            ai_response = ai_response.replace("\\_", "_")
            ai_response = ai_response.replace("\\*", "*")
            ai_response = ai_response.replace("\t", "")

            # ai_response = ai_response.strip().replace("\\n", " ").replace("\n", " ")
            logger.info(f"un_stream ai response: {ai_response}")
            return ai_response
        else:
            raise ValueError(
                f"Model server error! code={resp_obj_ex['error_code']}, error msg is "
                f"{resp_obj_ex['text']}"
            )
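
    # A sketch of the non-streaming flow (illustrative comment only; the
    # separator comes from the caller, e.g. map() below passes
    # "#####################"):
    #   all_text = "system: ...#####################assistant: SELECT 1"
    #   all_text.split(sep) -> ["system: ...", "assistant: SELECT 1"]
    #   the last segment containing "assistant:" is kept, the role prefix is
    #   stripped, and "SELECT 1" is returned.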

    def _illegal_json_ends(self, s):
        """Repair a trailing comma before a closing brace or bracket."""
        temp_json = s
        illegal_json_ends_1 = [", }", ",}"]
        illegal_json_ends_2 = [", ]", ",]"]
        for illegal_json_end in illegal_json_ends_1:
            temp_json = temp_json.replace(illegal_json_end, " }")
        for illegal_json_end in illegal_json_ends_2:
            temp_json = temp_json.replace(illegal_json_end, " ]")
        return temp_json
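
    # Illustrative examples for _illegal_json_ends (comments only):
    #   '{"sql": "SELECT 1", }' -> '{"sql": "SELECT 1" }'
    #   '[1, 2,]'               -> '[1, 2 ]'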

    def _extract_json(self, s):
        try:
            # Run both extraction modes first and keep the longer result.
            temp_json_simple = self._json_interception(s)
            temp_json_array = self._json_interception(s, True)
            if len(temp_json_simple) > len(temp_json_array):
                temp_json = temp_json_simple
            else:
                temp_json = temp_json_array

            if not temp_json:
                temp_json = self._json_interception(s)

            temp_json = self._illegal_json_ends(temp_json)
            return temp_json
        except Exception:
            # Report the raw input: temp_json may be unbound if the
            # extraction itself raised.
            raise ValueError("Failed to find a valid json in LLM response! " + s)

    def _json_interception(self, s, is_json_array: bool = False):
        """Extract the first balanced JSON object (or array) from a string.

        Scans from the first opening bracket and tracks nesting depth until
        it returns to zero; returns "" when no balanced span is found.
        """
        open_char, close_char = ("[", "]") if is_json_array else ("{", "}")
        try:
            i = s.find(open_char)
            if i < 0:
                return ""
            count = 1
            for j, c in enumerate(s[i + 1 :], start=i + 1):
                if c == close_char:
                    count -= 1
                elif c == open_char:
                    count += 1
                if count == 0:
                    break
            assert count == 0
            return s[i : j + 1]
        except Exception:
            return ""
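
    # Illustrative examples for _json_interception (comments only):
    #   _json_interception('text {"a": {"b": 1}} tail')     -> '{"a": {"b": 1}}'
    #   _json_interception("rows: [1, [2, 3]] done", True)  -> '[1, [2, 3]]'
    #   _json_interception("no braces here")                -> ''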

    def parse_prompt_response(self, model_out_text) -> Any:
        """Parse model out text to prompt define response.

        Args:
            model_out_text: The output of an LLM call.

        Returns:
            Any: The parsed output of an LLM call.
        """
        cleaned_output = model_out_text.rstrip()
        if "```json" in cleaned_output:
            # Take the text after the first ```json fence; partition is safe
            # even when the fence appears more than once.
            cleaned_output = cleaned_output.partition("```json")[2]
        # if "```" in cleaned_output:
        #     cleaned_output, _ = cleaned_output.split("```")
        if cleaned_output.startswith("```json"):
            cleaned_output = cleaned_output[len("```json") :]
        if cleaned_output.startswith("```"):
            cleaned_output = cleaned_output[len("```") :]
        if cleaned_output.endswith("```"):
            cleaned_output = cleaned_output[: -len("```")]
        cleaned_output = cleaned_output.strip()
        if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
            logger.info("illegal json processing:\n" + cleaned_output)
            cleaned_output = self._extract_json(cleaned_output)

        if not cleaned_output:
            return model_out_text

        # Unescape "\_" first so the generic backslash removal below cannot
        # clobber it.
        cleaned_output = (
            cleaned_output.strip()
            .replace("\\_", "_")
            .replace("\\n", " ")
            .replace("\n", " ")
            .replace("\\", " ")
        )
        cleaned_output = self._illegal_json_ends(cleaned_output)
        return cleaned_output
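
    # A worked example for parse_prompt_response (illustrative comment only):
    #   input:  'Sure! Here it is:\n```json\n{"sql": "SELECT 1",}\n```'
    #   after fence stripping      -> '{"sql": "SELECT 1",}'
    #   after _illegal_json_ends   -> '{"sql": "SELECT 1" }'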

    def parse_view_response(
        self, ai_text, data, parse_prompt_response: Any = None
    ) -> str:
        """Parse the AI response info to user view.

        Args:
            ai_text (str): The output of an LLM call.
            data (dict): Data that has already been handled by some scene.
            parse_prompt_response (Any): The parsed prompt response.

        Returns:
            str: The parsed output of an LLM call.
        """
        return ai_text

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    async def map(self, input_value: ModelOutput) -> Any:
        """Parse the output of an LLM call.

        Args:
            input_value (ModelOutput): The output of an LLM call.

        Returns:
            Any: The parsed output of an LLM call.
        """
        if self.current_dag_context.streaming_call:
            return self.parse_model_stream_resp_ex(input_value, 0)
        else:
            return self.parse_model_nostream_resp(input_value, "#####################")


def _parse_model_response(response: ResponseTye):
    """Normalize a model response to a dict (or "" for a None response)."""
    if response is None:
        resp_obj_ex = ""
    elif isinstance(response, ModelOutput):
        resp_obj_ex = asdict(response)
    elif isinstance(response, str):
        resp_obj_ex = json.loads(response)
    elif isinstance(response, bytes):
        # Strip NUL bytes before JSON decoding; they are invalid control
        # characters for the json parser.
        if b"\0" in response:
            response = response.replace(b"\0", b"")
        resp_obj_ex = json.loads(response.decode())
    else:
        raise ValueError(f"Unsupported response type {type(response)}")
    return resp_obj_ex
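

# A minimal usage sketch for _parse_model_response (illustrative; defined
# here only as an example and never called by the library code):
def _example_parse_model_response():
    raw = b'{"text": "hello", "error_code": 0}'
    data = _parse_model_response(raw)
    assert data == {"text": "hello", "error_code": 0}
    assert _parse_model_response(None) == ""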


class SQLOutputParser(BaseOutputParser):
    """Parse the SQL output of an LLM call."""

    metadata = ViewMetadata(
        label=_("SQL Output Parser"),
        name="default_sql_output_parser",
        category=OperatorCategory.OUTPUT_PARSER,
        description=_("Parse the SQL output of an LLM call."),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("Dict SQL Output"),
                "dict",
                dict,
                description=_("The dict output after parsing."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, is_stream_out: bool = False, **kwargs):
        """Create a new SQL output parser."""
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Parse the output of an LLM call into a SQL response dict."""
        model_out_text = super().parse_model_nostream_resp(response, sep)
        clean_str = super().parse_prompt_response(model_out_text)
        return json.loads(clean_str, strict=True)
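

# A minimal usage sketch for SQLOutputParser (illustrative comments only;
# assumes the model replied with a JSON body such as
# '{"thoughts": "...", "sql": "SELECT ..."}'):
#
#     parser = SQLOutputParser()
#     result = parser.parse_model_nostream_resp(response, sep)
#     result["sql"]  # -> "SELECT ..."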


class SQLListOutputParser(BaseOutputParser):
    """Parse the SQL list output of an LLM call."""

    metadata = ViewMetadata(
        label=_("SQL List Output Parser"),
        name="default_sql_list_output_parser",
        category=OperatorCategory.OUTPUT_PARSER,
        description=_(
            "Parse the SQL list output of an LLM call, mostly used for dashboard."
        ),
        parameters=[],
        inputs=[
            IOField.build_from(
                _("Model Output"),
                "model_output",
                ModelOutput,
                description=_("The model output of upstream."),
            )
        ],
        outputs=[
            IOField.build_from(
                _("List SQL Output"),
                "list",
                dict,
                is_list=True,
                description=_("The list output after parsing."),
            )
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, is_stream_out: bool = False, **kwargs):
        """Create a new SQL list output parser."""
        super().__init__(is_stream_out=is_stream_out, **kwargs)

    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
        """Parse the output of an LLM call into a list of SQL dicts."""
        from dbgpt.util.json_utils import find_json_objects

        model_out_text = super().parse_model_nostream_resp(response, sep)
        json_objects = find_json_objects(model_out_text)
        json_count = len(json_objects)
        if json_count < 1:
            raise ValueError("Unable to obtain valid output.")

        parsed_json_list = json_objects[0]
        if not isinstance(parsed_json_list, list):
            if isinstance(parsed_json_list, dict):
                # A single dict is wrapped so callers always receive a list.
                return [parsed_json_list]
            else:
                raise ValueError("Invalid output format.")
        return parsed_json_list
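

# A minimal usage sketch for SQLListOutputParser (illustrative comments only;
# assumes the model replied with a JSON array of chart/SQL dicts, e.g.
# '[{"sql": "SELECT ...", "title": "..."}]'):
#
#     parser = SQLListOutputParser()
#     charts = parser.parse_model_nostream_resp(response, sep)
#     for chart in charts:
#         print(chart["sql"])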