chore: Merge latest code

This commit is contained in:
Fangyin Cheng
2024-08-30 15:00:14 +08:00
parent 471689ba20
commit c67b50052d
26 changed files with 1643 additions and 68 deletions

View File

@@ -10,6 +10,7 @@ from ..util.parameter_util import ( # noqa: F401
VariablesDynamicOptions,
)
from .base import ( # noqa: F401
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -33,6 +34,7 @@ __ALL__ = [
"ResourceCategory",
"ResourceType",
"OperatorType",
"TAGS_ORDER_HIGH",
"IOField",
"BaseDynamicOptions",
"FunctionDynamicOptions",

View File

@@ -40,6 +40,9 @@ _BASIC_TYPES = [str, int, float, bool, dict, list, set]
T = TypeVar("T", bound="ViewMixin")
TM = TypeVar("TM", bound="TypeMetadata")
TAGS_ORDER_HIGH = "higher-order"
TAGS_ORDER_FIRST = "first-order"
def _get_type_name(type_: Type[Any]) -> str:
"""Get the type name of the type.
@@ -143,6 +146,8 @@ _OPERATOR_CATEGORY_DETAIL = {
"agent": _CategoryDetail("Agent", "The agent operator"),
"rag": _CategoryDetail("RAG", "The RAG operator"),
"experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
"database": _CategoryDetail("Database", "Interact with the database"),
"type_converter": _CategoryDetail("Type Converter", "Convert the type"),
"example": _CategoryDetail("Example", "Example operator"),
}
@@ -159,6 +164,8 @@ class OperatorCategory(str, Enum):
AGENT = "agent"
RAG = "rag"
EXPERIMENTAL = "experimental"
DATABASE = "database"
TYPE_CONVERTER = "type_converter"
EXAMPLE = "example"
def label(self) -> str:
@@ -202,6 +209,7 @@ _RESOURCE_CATEGORY_DETAIL = {
"embeddings": _CategoryDetail("Embeddings", "The embeddings resource"),
"rag": _CategoryDetail("RAG", "The resource"),
"vector_store": _CategoryDetail("Vector Store", "The vector store resource"),
"database": _CategoryDetail("Database", "Interact with the database"),
"example": _CategoryDetail("Example", "The example resource"),
}
@@ -219,6 +227,7 @@ class ResourceCategory(str, Enum):
EMBEDDINGS = "embeddings"
RAG = "rag"
VECTOR_STORE = "vector_store"
DATABASE = "database"
EXAMPLE = "example"
def label(self) -> str:
@@ -372,32 +381,41 @@ class Parameter(TypeMetadata, Serializable):
"value": values.get("value"),
"default": values.get("default"),
}
is_list = values.get("is_list") or False
if type_cls:
for k, v in to_handle_values.items():
if v:
handled_v = cls._covert_to_real_type(type_cls, v)
handled_v = cls._covert_to_real_type(type_cls, v, is_list)
values[k] = handled_v
return values
@classmethod
def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
if type_cls and v is not None:
typed_value: Any = v
def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any:
def _parse_single_value(vv: Any) -> Any:
typed_value: Any = vv
try:
# Try to convert the value to the type.
if type_cls == "builtins.str":
typed_value = str(v)
typed_value = str(vv)
elif type_cls == "builtins.int":
typed_value = int(v)
typed_value = int(vv)
elif type_cls == "builtins.float":
typed_value = float(v)
typed_value = float(vv)
elif type_cls == "builtins.bool":
if str(v).lower() in ["false", "0", "", "no", "off"]:
if str(vv).lower() in ["false", "0", "", "no", "off"]:
return False
typed_value = bool(v)
typed_value = bool(vv)
return typed_value
except ValueError:
raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}")
if type_cls and v is not None:
if not is_list:
_parse_single_value(v)
else:
if not isinstance(v, list):
raise ValidationError(f"Value '{v}' is not a list.")
return [_parse_single_value(vv) for vv in v]
return v
def get_typed_value(self) -> Any:
@@ -413,11 +431,11 @@ class Parameter(TypeMetadata, Serializable):
if is_variables and self.value is not None and isinstance(self.value, str):
return VariablesPlaceHolder(self.name, self.value)
else:
return self._covert_to_real_type(self.type_cls, self.value)
return self._covert_to_real_type(self.type_cls, self.value, self.is_list)
def get_typed_default(self) -> Any:
"""Get the typed default."""
return self._covert_to_real_type(self.type_cls, self.default)
return self._covert_to_real_type(self.type_cls, self.default, self.is_list)
@classmethod
def build_from(
@@ -499,7 +517,10 @@ class Parameter(TypeMetadata, Serializable):
values = self.options.option_values()
dict_value["options"] = [value.to_dict() for value in values]
else:
dict_value["options"] = [value.to_dict() for value in self.options]
dict_value["options"] = [
value.to_dict() if not isinstance(value, dict) else value
for value in self.options
]
if self.ui:
dict_value["ui"] = self.ui.to_dict()
@@ -594,6 +615,17 @@ class Parameter(TypeMetadata, Serializable):
value = view_value
return {self.name: value}
def new(self: TM) -> TM:
"""Copy the metadata."""
new_obj = self.__class__(
**self.model_dump(exclude_defaults=True, exclude={"ui", "options"})
)
if self.ui:
new_obj.ui = self.ui
if self.options:
new_obj.options = self.options
return new_obj
class BaseResource(Serializable, BaseModel):
"""The base resource."""
@@ -644,6 +676,17 @@ class IOField(Resource):
description="Whether current field is list",
examples=[True, False],
)
dynamic: bool = Field(
default=False,
description="Whether current field is dynamic",
examples=[True, False],
)
dynamic_minimum: int = Field(
default=0,
description="The minimum count of the dynamic field, only valid when dynamic is"
" True",
examples=[0, 1, 2],
)
@classmethod
def build_from(
@@ -653,6 +696,8 @@ class IOField(Resource):
type: Type,
description: Optional[str] = None,
is_list: bool = False,
dynamic: bool = False,
dynamic_minimum: int = 0,
):
"""Build the resource from the type."""
type_name = type.__qualname__
@@ -664,8 +709,22 @@ class IOField(Resource):
type_cls=type_cls,
is_list=is_list,
description=description or label,
dynamic=dynamic,
dynamic_minimum=dynamic_minimum,
)
@model_validator(mode="before")
@classmethod
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Pre fill the metadata."""
if not isinstance(values, dict):
return values
if "dynamic" not in values:
values["dynamic"] = False
if "dynamic_minimum" not in values:
values["dynamic_minimum"] = 0
return values
class BaseMetadata(BaseResource):
"""The base metadata."""
@@ -808,9 +867,40 @@ class BaseMetadata(BaseResource):
split_ids = self.id.split("_")
return "_".join(split_ids[:-1])
def _parse_ui_size(self) -> Optional[str]:
"""Parse the ui size."""
if not self.parameters:
return None
parameters_size = set()
for parameter in self.parameters:
if parameter.ui and parameter.ui.size:
parameters_size.add(parameter.ui.size)
for size in ["large", "middle", "small"]:
if size in parameters_size:
return size
return None
def to_dict(self) -> Dict:
"""Convert current metadata to json dict."""
from .ui import _size_to_order
dict_value = model_to_dict(self, exclude={"parameters"})
tags = dict_value.get("tags")
if not tags:
tags = {"ui_version": "flow2.0"}
elif isinstance(tags, dict) and "ui_version" not in tags:
tags["ui_version"] = "flow2.0"
parsed_ui_size = self._parse_ui_size()
if parsed_ui_size:
exist_size = tags.get("ui_size")
if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order(
exist_size
):
# Use the higher order size as current size.
tags["ui_size"] = parsed_ui_size
dict_value["tags"] = tags
dict_value["parameters"] = [
parameter.to_dict() for parameter in self.parameters
]

View File

@@ -97,6 +97,12 @@ class FlowNodeData(BaseModel):
return ResourceMetadata(**value)
raise ValueError("Unable to infer the type for `data`")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
dict_value = model_to_dict(self, exclude={"data"})
dict_value["data"] = self.data.to_dict()
return dict_value
class FlowEdgeData(BaseModel):
"""Edge data in a flow."""
@@ -166,6 +172,12 @@ class FlowData(BaseModel):
edges: List[FlowEdgeData] = Field(..., description="Edges in the flow")
viewport: FlowPositionData = Field(..., description="Viewport of the flow")
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
dict_value = model_to_dict(self, exclude={"nodes"})
dict_value["nodes"] = [n.to_dict() for n in self.nodes]
return dict_value
class _VariablesRequestBase(BaseModel):
key: str = Field(
@@ -518,9 +530,24 @@ class FlowPanel(BaseModel):
values["name"] = name
return values
def model_dump(self, **kwargs):
"""Override the model dump method."""
exclude = kwargs.get("exclude", set())
if "flow_dag" not in exclude:
exclude.add("flow_dag")
if "flow_data" not in exclude:
exclude.add("flow_data")
kwargs["exclude"] = exclude
common_dict = super().model_dump(**kwargs)
if self.flow_dag:
common_dict["flow_dag"] = None
if self.flow_data:
common_dict["flow_data"] = self.flow_data.to_dict()
return common_dict
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict."""
return model_to_dict(self, exclude={"flow_dag"})
return model_to_dict(self, exclude={"flow_dag", "flow_data"})
def get_variables_dict(self) -> List[Dict[str, Any]]:
"""Get the variables dict."""
@@ -568,6 +595,11 @@ class FlowFactory:
key_to_resource_nodes[key] = node
key_to_resource[key] = node.data
if not key_to_operator_nodes and not key_to_resource_nodes:
raise FlowMetadataException(
"No operator or resource nodes found in the flow."
)
for edge in flow_data.edges:
source_key = edge.source
target_key = edge.target
@@ -943,11 +975,17 @@ def fill_flow_panel(flow_panel: FlowPanel):
new_param = input_parameters[i.name]
i.label = new_param.label
i.description = new_param.description
i.dynamic = new_param.dynamic
i.is_list = new_param.is_list
i.dynamic_minimum = new_param.dynamic_minimum
for i in node.data.outputs:
if i.name in output_parameters:
new_param = output_parameters[i.name]
i.label = new_param.label
i.description = new_param.description
i.dynamic = new_param.dynamic
i.is_list = new_param.is_list
i.dynamic_minimum = new_param.dynamic_minimum
else:
data = cast(ResourceMetadata, node.data)
key = data.get_origin_id()
@@ -972,6 +1010,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
param.options = new_param.get_dict_options() # type: ignore
param.default = new_param.default
param.placeholder = new_param.placeholder
param.alias = new_param.alias
param.ui = new_param.ui
except (FlowException, ValueError) as e:
logger.warning(f"Unable to fill the flow panel: {e}")

View File

@@ -2,7 +2,7 @@
from typing import Any, Dict, List, Literal, Optional, Union
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator
from dbgpt.core.interface.serialization import Serializable
from .exceptions import FlowUIComponentException
@@ -25,6 +25,16 @@ _UI_TYPE = Literal[
"code_editor",
]
_UI_SIZE_TYPE = Literal["large", "middle", "small"]
_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2}
def _size_to_order(size: str) -> int:
"""Convert size to order."""
if size not in _SIZE_ORDER:
return -1
return _SIZE_ORDER[size]
class RefreshableMixin(BaseModel):
"""Refreshable mixin."""
@@ -81,6 +91,10 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
)
ui_type: _UI_TYPE = Field(..., description="UI component type")
size: Optional[_UI_SIZE_TYPE] = Field(
None,
description="The size of the component(small, middle, large)",
)
attr: Optional[UIAttribute] = Field(
None,
@@ -266,6 +280,27 @@ class UITextArea(PanelEditorMixin, UIInput):
description="The attributes of the component",
)
@model_validator(mode="after")
def check_size(self) -> "UITextArea":
"""Check the size.
Automatically set the size to large if the max_rows is greater than 10.
"""
attr = self.attr
auto_size = attr.auto_size if attr else None
if not attr or not auto_size or isinstance(auto_size, bool):
return self
max_rows = (
auto_size.max_rows
if isinstance(auto_size, self.UIAttribute.AutoSize)
else None
)
size = self.size
if not size and max_rows and max_rows > 10:
# Automatically set the size to large if the max_rows is greater than 10
self.size = "large"
return self
class UIAutoComplete(UIInput):
"""Auto complete component."""
@@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea):
attr: Optional[UITextArea.UIAttribute] = Field(
default_factory=lambda: UITextArea.UIAttribute(
auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20)
),
description="The attributes of the component",
)

View File

@@ -29,6 +29,7 @@ from dbgpt.util.tracer import root_tracer
from ..dag.base import DAG
from ..flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
_PARAMETER_MEDIA_TYPE.new(),
_PARAMETER_STATUS_CODE.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(
@@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
"User input parsed operator, parse the user input from request body and "
"return as a string"
),
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, key: str = "user_input", **kwargs):

View File

@@ -195,6 +195,9 @@ class ModelRequest:
temperature: Optional[float] = None
"""The temperature of the model inference."""
top_p: Optional[float] = None
"""The top p of the model inference."""
max_new_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""

View File

@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
"""
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
@staticmethod
def parse_user_message(messages: List[ModelMessage]) -> str:
"""Parse user message from messages.
Args:
messages (List[ModelMessage]): The all messages in the conversation.
Returns:
str: The user message
"""
lass_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
lass_user_message = message.content
break
if not lass_user_message:
raise ValueError("No user message")
return lass_user_message
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
additional_kwargs=ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {},
additional_kwargs=(
ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {}
),
)
current_round.append(view_message)
return sum(messages_by_round, [])

View File

@@ -246,10 +246,16 @@ class BaseLLM:
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
def __init__(self, llm_client: Optional[LLMClient] = None):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
):
"""Create a new LLM operator."""
self._llm_client = llm_client
self._save_model_output = save_model_output
@property
def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@ class BaseLLM:
self, current_dag_context: DAGContext, model_output: ModelOutput
) -> None:
"""Save the model output to the share data."""
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
if self._save_model_output:
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
This operator will generate a no streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
"""Create a new LLM operator."""
super().__init__(llm_client=llm_client)
super().__init__(llm_client=llm_client, save_model_output=save_model_output)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
This operator will generate streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
"""Create a streaming operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
"""
super().__init__(llm_client=llm_client)
super().__init__(llm_client=llm_client, save_model_output=save_model_output)
BaseOperator.__init__(self, **kwargs)
async def streamify( # type: ignore

View File

@@ -4,14 +4,10 @@ from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import model_validator
from dbgpt.core import (
ModelMessage,
ModelMessageRoleType,
ModelOutput,
StorageConversation,
)
from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
name="common_chat_prompt_template",
category=ResourceCategory.PROMPT,
description=_("The operator to build the prompt with static prompt."),
tags={"order": TAGS_ORDER_HIGH},
parameters=[
Parameter.build_from(
label=_("System Message"),
@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""
def __init__(self, check_storage: bool, **kwargs):
def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
"""Create a new prompt builder operator."""
super().__init__(check_storage=check_storage, **kwargs)
self._save_to_storage = save_to_storage
async def format_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
model_messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(model_messages)
if self._save_to_storage:
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(model_messages)
return model_messages
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
Args:
messages (List[ModelMessage]): The messages.
"""
lass_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
lass_user_message = message.content
break
if not lass_user_message:
raise ValueError("No user message")
lass_user_message = ModelMessage.parse_user_message(messages)
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()
@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
async def after_dag_end(self, event_loop_task_id: int):
"""Execute after the DAG finished."""
if not self._save_to_storage:
return
# Save the storage conversation to storage after the whole DAG finished
storage_conv: Optional[
StorageConversation
@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
self._prompt = prompt
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
"""Create a new history dynamic prompt builder operator."""
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type

View File

@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union
from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
ViewMetadata,
)
from dbgpt.util.i18n_utils import _
T = TypeVar("T")
@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
if self.current_dag_context.streaming_call:
return self.parse_model_stream_resp_ex(input_value, 0)
else:
return self.parse_model_nostream_resp(input_value, "###")
return self.parse_model_nostream_resp(input_value, "#####################")
def _parse_model_response(response: ResponseTye):
@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
class SQLOutputParser(BaseOutputParser):
"""Parse the SQL output of an LLM call."""
metadata = ViewMetadata(
label=_("SQL Output Parser"),
name="default_sql_output_parser",
category=OperatorCategory.OUTPUT_PARSER,
description=_("Parse the SQL output of an LLM call."),
parameters=[],
inputs=[
IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output of upstream."),
)
],
outputs=[
IOField.build_from(
_("Dict SQL Output"),
"dict",
dict,
description=_("The dict output after parsing."),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)
@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
model_out_text = super().parse_model_nostream_resp(response, sep)
clean_str = super().parse_prompt_response(model_out_text)
return json.loads(clean_str, strict=True)
class SQLListOutputParser(BaseOutputParser):
"""Parse the SQL list output of an LLM call."""
metadata = ViewMetadata(
label=_("SQL List Output Parser"),
name="default_sql_list_output_parser",
category=OperatorCategory.OUTPUT_PARSER,
description=_(
"Parse the SQL list output of an LLM call, mostly used for dashboard."
),
parameters=[],
inputs=[
IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output of upstream."),
)
],
outputs=[
IOField.build_from(
_("List SQL Output"),
"list",
dict,
is_list=True,
description=_("The list output after parsing."),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL list output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
from dbgpt.util.json_utils import find_json_objects
model_out_text = super().parse_model_nostream_resp(response, sep)
json_objects = find_json_objects(model_out_text)
json_count = len(json_objects)
if json_count < 1:
raise ValueError("Unable to obtain valid output.")
parsed_json_list = json_objects[0]
if not isinstance(parsed_json_list, list):
if isinstance(parsed_json_list, dict):
return [parsed_json_list]
else:
raise ValueError("Invalid output format.")
return parsed_json_list

View File

@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
values["input_variables"] = sorted(input_variables)
return values
def get_placeholders(self) -> List[str]:
"""Get all placeholders in the prompt template.
Returns:
List[str]: The placeholders.
"""
placeholders = set()
for message in self.messages:
if isinstance(message, MessagesPlaceholder):
placeholders.add(message.variable_name)
return sorted(placeholders)
@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):