mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-26 12:20:39 +00:00
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
This commit is contained in:
@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union
|
||||
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
|
||||
from dbgpt.core.awel.flow import (
|
||||
TAGS_ORDER_HIGH,
|
||||
IOField,
|
||||
OperatorCategory,
|
||||
OperatorType,
|
||||
ViewMetadata,
|
||||
)
|
||||
from dbgpt.util.i18n_utils import _
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
if self.current_dag_context.streaming_call:
|
||||
return self.parse_model_stream_resp_ex(input_value, 0)
|
||||
else:
|
||||
return self.parse_model_nostream_resp(input_value, "###")
|
||||
return self.parse_model_nostream_resp(input_value, "#####################")
|
||||
|
||||
|
||||
def _parse_model_response(response: ResponseTye):
|
||||
@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
|
||||
class SQLOutputParser(BaseOutputParser):
|
||||
"""Parse the SQL output of an LLM call."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("SQL Output Parser"),
|
||||
name="default_sql_output_parser",
|
||||
category=OperatorCategory.OUTPUT_PARSER,
|
||||
description=_("Parse the SQL output of an LLM call."),
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Model Output"),
|
||||
"model_output",
|
||||
ModelOutput,
|
||||
description=_("The model output of upstream."),
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
_("Dict SQL Output"),
|
||||
"dict",
|
||||
dict,
|
||||
description=_("The dict output after parsing."),
|
||||
)
|
||||
],
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
)
|
||||
|
||||
def __init__(self, is_stream_out: bool = False, **kwargs):
|
||||
"""Create a new SQL output parser."""
|
||||
super().__init__(is_stream_out=is_stream_out, **kwargs)
|
||||
@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
|
||||
model_out_text = super().parse_model_nostream_resp(response, sep)
|
||||
clean_str = super().parse_prompt_response(model_out_text)
|
||||
return json.loads(clean_str, strict=True)
|
||||
|
||||
|
||||
class SQLListOutputParser(BaseOutputParser):
|
||||
"""Parse the SQL list output of an LLM call."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label=_("SQL List Output Parser"),
|
||||
name="default_sql_list_output_parser",
|
||||
category=OperatorCategory.OUTPUT_PARSER,
|
||||
description=_(
|
||||
"Parse the SQL list output of an LLM call, mostly used for dashboard."
|
||||
),
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
_("Model Output"),
|
||||
"model_output",
|
||||
ModelOutput,
|
||||
description=_("The model output of upstream."),
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
_("List SQL Output"),
|
||||
"list",
|
||||
dict,
|
||||
is_list=True,
|
||||
description=_("The list output after parsing."),
|
||||
)
|
||||
],
|
||||
tags={"order": TAGS_ORDER_HIGH},
|
||||
)
|
||||
|
||||
def __init__(self, is_stream_out: bool = False, **kwargs):
|
||||
"""Create a new SQL list output parser."""
|
||||
super().__init__(is_stream_out=is_stream_out, **kwargs)
|
||||
|
||||
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
|
||||
"""Parse the output of an LLM call."""
|
||||
from dbgpt.util.json_utils import find_json_objects
|
||||
|
||||
model_out_text = super().parse_model_nostream_resp(response, sep)
|
||||
json_objects = find_json_objects(model_out_text)
|
||||
json_count = len(json_objects)
|
||||
if json_count < 1:
|
||||
raise ValueError("Unable to obtain valid output.")
|
||||
|
||||
parsed_json_list = json_objects[0]
|
||||
if not isinstance(parsed_json_list, list):
|
||||
if isinstance(parsed_json_list, dict):
|
||||
return [parsed_json_list]
|
||||
else:
|
||||
raise ValueError("Invalid output format.")
|
||||
return parsed_json_list
|
||||
|
||||
Reference in New Issue
Block a user