mirror of https://github.com/csunny/DB-GPT.git (synced 2025-09-27 20:38:30 +00:00)

chore: Merge latest code
@@ -10,6 +10,7 @@ from ..util.parameter_util import (  # noqa: F401
     VariablesDynamicOptions,
 )
 from .base import (  # noqa: F401
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -33,6 +34,7 @@ __ALL__ = [
     "ResourceCategory",
     "ResourceType",
     "OperatorType",
+    "TAGS_ORDER_HIGH",
     "IOField",
     "BaseDynamicOptions",
     "FunctionDynamicOptions",

@@ -40,6 +40,9 @@ _BASIC_TYPES = [str, int, float, bool, dict, list, set]
 T = TypeVar("T", bound="ViewMixin")
 TM = TypeVar("TM", bound="TypeMetadata")
 
+TAGS_ORDER_HIGH = "higher-order"
+TAGS_ORDER_FIRST = "first-order"
+
 
 def _get_type_name(type_: Type[Any]) -> str:
     """Get the type name of the type.
@@ -143,6 +146,8 @@ _OPERATOR_CATEGORY_DETAIL = {
     "agent": _CategoryDetail("Agent", "The agent operator"),
     "rag": _CategoryDetail("RAG", "The RAG operator"),
     "experimental": _CategoryDetail("EXPERIMENTAL", "EXPERIMENTAL operator"),
+    "database": _CategoryDetail("Database", "Interact with the database"),
+    "type_converter": _CategoryDetail("Type Converter", "Convert the type"),
     "example": _CategoryDetail("Example", "Example operator"),
 }
 
@@ -159,6 +164,8 @@ class OperatorCategory(str, Enum):
     AGENT = "agent"
     RAG = "rag"
     EXPERIMENTAL = "experimental"
+    DATABASE = "database"
+    TYPE_CONVERTER = "type_converter"
     EXAMPLE = "example"
 
     def label(self) -> str:
@@ -202,6 +209,7 @@ _RESOURCE_CATEGORY_DETAIL = {
     "embeddings": _CategoryDetail("Embeddings", "The embeddings resource"),
     "rag": _CategoryDetail("RAG", "The resource"),
     "vector_store": _CategoryDetail("Vector Store", "The vector store resource"),
+    "database": _CategoryDetail("Database", "Interact with the database"),
     "example": _CategoryDetail("Example", "The example resource"),
 }
 
@@ -219,6 +227,7 @@ class ResourceCategory(str, Enum):
     EMBEDDINGS = "embeddings"
     RAG = "rag"
     VECTOR_STORE = "vector_store"
+    DATABASE = "database"
     EXAMPLE = "example"
 
     def label(self) -> str:
@@ -372,32 +381,41 @@ class Parameter(TypeMetadata, Serializable):
             "value": values.get("value"),
             "default": values.get("default"),
         }
+        is_list = values.get("is_list") or False
         if type_cls:
             for k, v in to_handle_values.items():
                 if v:
-                    handled_v = cls._covert_to_real_type(type_cls, v)
+                    handled_v = cls._covert_to_real_type(type_cls, v, is_list)
                     values[k] = handled_v
         return values
 
     @classmethod
-    def _covert_to_real_type(cls, type_cls: str, v: Any) -> Any:
-        if type_cls and v is not None:
-            typed_value: Any = v
+    def _covert_to_real_type(cls, type_cls: str, v: Any, is_list: bool) -> Any:
+        def _parse_single_value(vv: Any) -> Any:
+            typed_value: Any = vv
             try:
                 # Try to convert the value to the type.
                 if type_cls == "builtins.str":
-                    typed_value = str(v)
+                    typed_value = str(vv)
                 elif type_cls == "builtins.int":
-                    typed_value = int(v)
+                    typed_value = int(vv)
                 elif type_cls == "builtins.float":
-                    typed_value = float(v)
+                    typed_value = float(vv)
                 elif type_cls == "builtins.bool":
-                    if str(v).lower() in ["false", "0", "", "no", "off"]:
+                    if str(vv).lower() in ["false", "0", "", "no", "off"]:
                         return False
-                    typed_value = bool(v)
+                    typed_value = bool(vv)
                 return typed_value
             except ValueError:
-                raise ValidationError(f"Value '{v}' is not valid for type {type_cls}")
+                raise ValidationError(f"Value '{vv}' is not valid for type {type_cls}")
+
+        if type_cls and v is not None:
+            if not is_list:
+                return _parse_single_value(v)
+            else:
+                if not isinstance(v, list):
+                    raise ValidationError(f"Value '{v}' is not a list.")
+                return [_parse_single_value(vv) for vv in v]
         return v
 
     def get_typed_value(self) -> Any:
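[Illustrative aside, not part of the diff] With the new `is_list` flag, list-typed
values are converted element-wise while scalars keep the previous behavior. A minimal
sketch of the resulting semantics (values are hypothetical):

    # Scalar conversion, as before:
    Parameter._covert_to_real_type("builtins.int", "42", False)       # -> 42
    # New: element-wise conversion for list parameters:
    Parameter._covert_to_real_type("builtins.int", ["1", "2"], True)  # -> [1, 2]
    # New: a non-list value with is_list=True raises ValidationError.
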
@@ -413,11 +431,11 @@ class Parameter(TypeMetadata, Serializable):
         if is_variables and self.value is not None and isinstance(self.value, str):
             return VariablesPlaceHolder(self.name, self.value)
         else:
-            return self._covert_to_real_type(self.type_cls, self.value)
+            return self._covert_to_real_type(self.type_cls, self.value, self.is_list)
 
     def get_typed_default(self) -> Any:
         """Get the typed default."""
-        return self._covert_to_real_type(self.type_cls, self.default)
+        return self._covert_to_real_type(self.type_cls, self.default, self.is_list)
 
     @classmethod
     def build_from(
@@ -499,7 +517,10 @@ class Parameter(TypeMetadata, Serializable):
             values = self.options.option_values()
             dict_value["options"] = [value.to_dict() for value in values]
         else:
-            dict_value["options"] = [value.to_dict() for value in self.options]
+            dict_value["options"] = [
+                value.to_dict() if not isinstance(value, dict) else value
+                for value in self.options
+            ]
 
         if self.ui:
             dict_value["ui"] = self.ui.to_dict()
@@ -594,6 +615,17 @@ class Parameter(TypeMetadata, Serializable):
                 value = view_value
         return {self.name: value}
 
+    def new(self: TM) -> TM:
+        """Copy the metadata."""
+        new_obj = self.__class__(
+            **self.model_dump(exclude_defaults=True, exclude={"ui", "options"})
+        )
+        if self.ui:
+            new_obj.ui = self.ui
+        if self.options:
+            new_obj.options = self.options
+        return new_obj
+
 
 class BaseResource(Serializable, BaseModel):
     """The base resource."""
@@ -644,6 +676,17 @@ class IOField(Resource):
         description="Whether current field is list",
         examples=[True, False],
     )
+    dynamic: bool = Field(
+        default=False,
+        description="Whether current field is dynamic",
+        examples=[True, False],
+    )
+    dynamic_minimum: int = Field(
+        default=0,
+        description="The minimum count of the dynamic field, only valid when dynamic is"
+        " True",
+        examples=[0, 1, 2],
+    )
 
     @classmethod
     def build_from(
@@ -653,6 +696,8 @@ class IOField(Resource):
         type: Type,
         description: Optional[str] = None,
         is_list: bool = False,
+        dynamic: bool = False,
+        dynamic_minimum: int = 0,
     ):
         """Build the resource from the type."""
         type_name = type.__qualname__
@@ -664,8 +709,22 @@ class IOField(Resource):
             type_cls=type_cls,
             is_list=is_list,
             description=description or label,
+            dynamic=dynamic,
+            dynamic_minimum=dynamic_minimum,
         )
 
+    @model_validator(mode="before")
+    @classmethod
+    def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        """Pre fill the metadata."""
+        if not isinstance(values, dict):
+            return values
+        if "dynamic" not in values:
+            values["dynamic"] = False
+        if "dynamic_minimum" not in values:
+            values["dynamic_minimum"] = 0
+        return values
+
 
 class BaseMetadata(BaseResource):
     """The base metadata."""
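[Illustrative aside, not part of the diff] The new `dynamic`/`dynamic_minimum` fields
mark an IOField as accepting a variable number of connections. A minimal sketch of the
extended `build_from` call (field names and values are hypothetical):

    chunks_field = IOField.build_from(
        "Streaming Chunks",  # label
        "chunks",            # name
        str,                 # type
        is_list=True,
        dynamic=True,
        dynamic_minimum=1,   # the dynamic field requires at least one value
    )
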
@@ -808,9 +867,40 @@ class BaseMetadata(BaseResource):
         split_ids = self.id.split("_")
         return "_".join(split_ids[:-1])
 
+    def _parse_ui_size(self) -> Optional[str]:
+        """Parse the ui size."""
+        if not self.parameters:
+            return None
+        parameters_size = set()
+        for parameter in self.parameters:
+            if parameter.ui and parameter.ui.size:
+                parameters_size.add(parameter.ui.size)
+        for size in ["large", "middle", "small"]:
+            if size in parameters_size:
+                return size
+        return None
+
     def to_dict(self) -> Dict:
         """Convert current metadata to json dict."""
+        from .ui import _size_to_order
+
         dict_value = model_to_dict(self, exclude={"parameters"})
+        tags = dict_value.get("tags")
+        if not tags:
+            tags = {"ui_version": "flow2.0"}
+        elif isinstance(tags, dict) and "ui_version" not in tags:
+            tags["ui_version"] = "flow2.0"
+
+        parsed_ui_size = self._parse_ui_size()
+        if parsed_ui_size:
+            exist_size = tags.get("ui_size")
+            if not exist_size or _size_to_order(parsed_ui_size) > _size_to_order(
+                exist_size
+            ):
+                # Use the higher order size as current size.
+                tags["ui_size"] = parsed_ui_size
+
+        dict_value["tags"] = tags
         dict_value["parameters"] = [
             parameter.to_dict() for parameter in self.parameters
         ]

@@ -97,6 +97,12 @@ class FlowNodeData(BaseModel):
             return ResourceMetadata(**value)
         raise ValueError("Unable to infer the type for `data`")
 
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict."""
+        dict_value = model_to_dict(self, exclude={"data"})
+        dict_value["data"] = self.data.to_dict()
+        return dict_value
+
 
 class FlowEdgeData(BaseModel):
     """Edge data in a flow."""
@@ -166,6 +172,12 @@ class FlowData(BaseModel):
     edges: List[FlowEdgeData] = Field(..., description="Edges in the flow")
     viewport: FlowPositionData = Field(..., description="Viewport of the flow")
 
+    def to_dict(self) -> Dict[str, Any]:
+        """Convert to dict."""
+        dict_value = model_to_dict(self, exclude={"nodes"})
+        dict_value["nodes"] = [n.to_dict() for n in self.nodes]
+        return dict_value
+
 
 class _VariablesRequestBase(BaseModel):
     key: str = Field(
@@ -518,9 +530,24 @@ class FlowPanel(BaseModel):
         values["name"] = name
         return values
 
+    def model_dump(self, **kwargs):
+        """Override the model dump method."""
+        exclude = kwargs.get("exclude", set())
+        if "flow_dag" not in exclude:
+            exclude.add("flow_dag")
+        if "flow_data" not in exclude:
+            exclude.add("flow_data")
+        kwargs["exclude"] = exclude
+        common_dict = super().model_dump(**kwargs)
+        if self.flow_dag:
+            common_dict["flow_dag"] = None
+        if self.flow_data:
+            common_dict["flow_data"] = self.flow_data.to_dict()
+        return common_dict
+
     def to_dict(self) -> Dict[str, Any]:
         """Convert to dict."""
-        return model_to_dict(self, exclude={"flow_dag"})
+        return model_to_dict(self, exclude={"flow_dag", "flow_data"})
 
     def get_variables_dict(self) -> List[Dict[str, Any]]:
         """Get the variables dict."""
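[Illustrative aside, not part of the diff] After this change, dumping a panel keeps the
serializable node/edge data but drops the live DAG object. A sketch, with `flow_panel`
as a hypothetical FlowPanel instance:

    panel_dict = flow_panel.model_dump()
    panel_dict["flow_dag"]   # -> None (live DAG objects are not serialized)
    panel_dict["flow_data"]  # -> plain dict produced by FlowData.to_dict()
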
@@ -568,6 +595,11 @@ class FlowFactory:
                 key_to_resource_nodes[key] = node
                 key_to_resource[key] = node.data
 
+        if not key_to_operator_nodes and not key_to_resource_nodes:
+            raise FlowMetadataException(
+                "No operator or resource nodes found in the flow."
+            )
+
         for edge in flow_data.edges:
             source_key = edge.source
             target_key = edge.target
@@ -943,11 +975,17 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     new_param = input_parameters[i.name]
                     i.label = new_param.label
                     i.description = new_param.description
+                    i.dynamic = new_param.dynamic
+                    i.is_list = new_param.is_list
+                    i.dynamic_minimum = new_param.dynamic_minimum
             for i in node.data.outputs:
                 if i.name in output_parameters:
                     new_param = output_parameters[i.name]
                     i.label = new_param.label
                     i.description = new_param.description
+                    i.dynamic = new_param.dynamic
+                    i.is_list = new_param.is_list
+                    i.dynamic_minimum = new_param.dynamic_minimum
         else:
             data = cast(ResourceMetadata, node.data)
             key = data.get_origin_id()
@@ -972,6 +1010,8 @@ def fill_flow_panel(flow_panel: FlowPanel):
                     param.options = new_param.get_dict_options()  # type: ignore
                     param.default = new_param.default
+                    param.placeholder = new_param.placeholder
+                    param.alias = new_param.alias
                     param.ui = new_param.ui
 
         except (FlowException, ValueError) as e:
             logger.warning(f"Unable to fill the flow panel: {e}")

@@ -2,7 +2,7 @@
 
 from typing import Any, Dict, List, Literal, Optional, Union
 
-from dbgpt._private.pydantic import BaseModel, Field, model_to_dict
+from dbgpt._private.pydantic import BaseModel, Field, model_to_dict, model_validator
 from dbgpt.core.interface.serialization import Serializable
 
 from .exceptions import FlowUIComponentException
@@ -25,6 +25,16 @@ _UI_TYPE = Literal[
     "code_editor",
 ]
 
+_UI_SIZE_TYPE = Literal["large", "middle", "small"]
+_SIZE_ORDER = {"large": 6, "middle": 4, "small": 2}
+
+
+def _size_to_order(size: str) -> int:
+    """Convert size to order."""
+    if size not in _SIZE_ORDER:
+        return -1
+    return _SIZE_ORDER[size]
+
 
 class RefreshableMixin(BaseModel):
     """Refreshable mixin."""
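[Illustrative aside, not part of the diff] `_size_to_order` ranks sizes so that
`BaseMetadata.to_dict` can keep the largest UI size seen across parameters:

    _size_to_order("large")    # -> 6
    _size_to_order("middle")   # -> 4
    _size_to_order("unknown")  # -> -1, so unrecognized sizes never win the comparison
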
@@ -81,6 +91,10 @@ class UIComponent(RefreshableMixin, Serializable, BaseModel):
     )
 
     ui_type: _UI_TYPE = Field(..., description="UI component type")
+    size: Optional[_UI_SIZE_TYPE] = Field(
+        None,
+        description="The size of the component(small, middle, large)",
+    )
 
     attr: Optional[UIAttribute] = Field(
         None,
@@ -266,6 +280,27 @@ class UITextArea(PanelEditorMixin, UIInput):
         description="The attributes of the component",
     )
 
+    @model_validator(mode="after")
+    def check_size(self) -> "UITextArea":
+        """Check the size.
+
+        Automatically set the size to large if the max_rows is greater than 10.
+        """
+        attr = self.attr
+        auto_size = attr.auto_size if attr else None
+        if not attr or not auto_size or isinstance(auto_size, bool):
+            return self
+        max_rows = (
+            auto_size.max_rows
+            if isinstance(auto_size, self.UIAttribute.AutoSize)
+            else None
+        )
+        size = self.size
+        if not size and max_rows and max_rows > 10:
+            # Automatically set the size to large if the max_rows is greater than 10
+            self.size = "large"
+        return self
+
 
 class UIAutoComplete(UIInput):
     """Auto complete component."""
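[Illustrative aside, not part of the diff] The new validator upgrades a text area to the
"large" size when its auto_size allows more than 10 rows. A sketch under the assumption
that the nested UIAttribute/AutoSize models can be constructed directly:

    area = UITextArea(
        attr=UITextArea.UIAttribute(
            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20)
        ),
    )
    area.size  # -> "large", set automatically because max_rows > 10
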
@@ -450,7 +485,7 @@ class DefaultUITextArea(UITextArea):
 
     attr: Optional[UITextArea.UIAttribute] = Field(
         default_factory=lambda: UITextArea.UIAttribute(
-            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=40)
+            auto_size=UITextArea.UIAttribute.AutoSize(min_rows=2, max_rows=20)
         ),
         description="The attributes of the component",
     )

@@ -29,6 +29,7 @@ from dbgpt.util.tracer import root_tracer
 
 from ..dag.base import DAG
 from ..flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -965,6 +966,7 @@ class CommonLLMHttpTrigger(HttpTrigger):
             _PARAMETER_MEDIA_TYPE.new(),
             _PARAMETER_STATUS_CODE.new(),
         ],
+        tags={"order": TAGS_ORDER_HIGH},
     )
 
     def __init__(
@@ -1203,6 +1205,7 @@ class RequestedParsedOperator(MapOperator[CommonLLMHttpRequestBody, str]):
             "User input parsed operator, parse the user input from request body and "
             "return as a string"
         ),
+        tags={"order": TAGS_ORDER_HIGH},
     )
 
     def __init__(self, key: str = "user_input", **kwargs):

@@ -195,6 +195,9 @@ class ModelRequest:
     temperature: Optional[float] = None
     """The temperature of the model inference."""
 
+    top_p: Optional[float] = None
+    """The top p of the model inference."""
+
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
 
@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
         """
         return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
 
+    @staticmethod
+    def parse_user_message(messages: List[ModelMessage]) -> str:
+        """Parse user message from messages.
+
+        Args:
+            messages (List[ModelMessage]): All messages in the conversation.
+
+        Returns:
+            str: The user message
+        """
+        last_user_message = None
+        for message in messages[::-1]:
+            if message.role == ModelMessageRoleType.HUMAN:
+                last_user_message = message.content
+                break
+        if not last_user_message:
+            raise ValueError("No user message")
+        return last_user_message
+
 
 _SingleRoundMessage = List[BaseMessage]
 _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
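[Illustrative aside, not part of the diff] `parse_user_message` walks the conversation
in reverse and returns the content of the most recent human message:

    messages = [
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hi"),
        ModelMessage(role=ModelMessageRoleType.AI, content="Hello!"),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is AWEL?"),
    ]
    ModelMessage.parse_user_message(messages)  # -> "What is AWEL?"
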
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
             content=ai_message.content,
             index=ai_message.index,
             round_index=ai_message.round_index,
-            additional_kwargs=ai_message.additional_kwargs.copy()
-            if ai_message.additional_kwargs
-            else {},
+            additional_kwargs=(
+                ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {}
+            ),
         )
         current_round.append(view_message)
     return sum(messages_by_round, [])

@@ -246,10 +246,16 @@ class BaseLLM:
 
     SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
     SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
+    SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
 
-    def __init__(self, llm_client: Optional[LLMClient] = None):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+    ):
         """Create a new LLM operator."""
         self._llm_client = llm_client
+        self._save_model_output = save_model_output
 
     @property
     def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@ class BaseLLM:
         self, current_dag_context: DAGContext, model_output: ModelOutput
     ) -> None:
         """Save the model output to the share data."""
-        await current_dag_context.save_to_share_data(
-            self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
-        )
+        if self._save_model_output:
+            await current_dag_context.save_to_share_data(
+                self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
+            )
 
 
 class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
     This operator will generate a no streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a new LLM operator."""
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         MapOperator.__init__(self, **kwargs)
 
     async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
     This operator will generate streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a streaming operator for a LLM.
 
         Args:
             llm_client (LLMClient, optional): The LLM client. Defaults to None.
         """
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         BaseOperator.__init__(self, **kwargs)
 
     async def streamify(  # type: ignore
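[Illustrative aside, not part of the diff] The new `save_model_output` flag lets a DAG
contain intermediate LLM calls whose output should not be written to DAG share data,
so it cannot overwrite the final model output. A hypothetical usage sketch, with
`MyLLMOperator` standing in for any concrete BaseLLMOperator subclass:

    intermediate = MyLLMOperator(llm_client=client, save_model_output=False)
    final = MyLLMOperator(llm_client=client)  # defaults to True, output is saved
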
@@ -4,14 +4,10 @@ from abc import ABC
 from typing import Any, Dict, List, Optional, Union
 
 from dbgpt._private.pydantic import model_validator
-from dbgpt.core import (
-    ModelMessage,
-    ModelMessageRoleType,
-    ModelOutput,
-    StorageConversation,
-)
+from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
 from dbgpt.core.awel import JoinOperator, MapOperator
 from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
     name="common_chat_prompt_template",
     category=ResourceCategory.PROMPT,
     description=_("The operator to build the prompt with static prompt."),
+    tags={"order": TAGS_ORDER_HIGH},
     parameters=[
         Parameter.build_from(
            label=_("System Message"),
@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
 class BasePromptBuilderOperator(BaseConversationOperator, ABC):
     """The base prompt builder operator."""
 
-    def __init__(self, check_storage: bool, **kwargs):
+    def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
         """Create a new prompt builder operator."""
         super().__init__(check_storage=check_storage, **kwargs)
+        self._save_to_storage = save_to_storage
 
     async def format_prompt(
         self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
         messages = prompt.format_messages(**pass_kwargs)
         model_messages = ModelMessage.from_base_messages(messages)
-        # Start new round conversation, and save user message to storage
-        await self.start_new_round_conv(model_messages)
+        if self._save_to_storage:
+            # Start new round conversation, and save user message to storage
+            await self.start_new_round_conv(model_messages)
         return model_messages
 
     async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         Args:
             messages (List[ModelMessage]): The messages.
         """
-        lass_user_message = None
-        for message in messages[::-1]:
-            if message.role == ModelMessageRoleType.HUMAN:
-                lass_user_message = message.content
-                break
-        if not lass_user_message:
-            raise ValueError("No user message")
+        lass_user_message = ModelMessage.parse_user_message(messages)
         storage_conv: Optional[
             StorageConversation
         ] = await self.get_storage_conversation()
@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
 
     async def after_dag_end(self, event_loop_task_id: int):
         """Execute after the DAG finished."""
+        if not self._save_to_storage:
+            return
         # Save the storage conversation to storage after the whole DAG finished
         storage_conv: Optional[
             StorageConversation
@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
 
     @rearrange_args_by_type
@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
         """Create a new history dynamic prompt builder operator."""
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
 
     @rearrange_args_by_type

@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union
 
 from dbgpt.core import ModelOutput
 from dbgpt.core.awel import MapOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.util.i18n_utils import _
 
 T = TypeVar("T")
@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         if self.current_dag_context.streaming_call:
             return self.parse_model_stream_resp_ex(input_value, 0)
         else:
-            return self.parse_model_nostream_resp(input_value, "###")
+            return self.parse_model_nostream_resp(input_value, "#####################")
 
 
 def _parse_model_response(response: ResponseTye):
@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
 class SQLOutputParser(BaseOutputParser):
     """Parse the SQL output of an LLM call."""
 
+    metadata = ViewMetadata(
+        label=_("SQL Output Parser"),
+        name="default_sql_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_("Parse the SQL output of an LLM call."),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("Dict SQL Output"),
+                "dict",
+                dict,
+                description=_("The dict output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
     def __init__(self, is_stream_out: bool = False, **kwargs):
         """Create a new SQL output parser."""
         super().__init__(is_stream_out=is_stream_out, **kwargs)
@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
         model_out_text = super().parse_model_nostream_resp(response, sep)
         clean_str = super().parse_prompt_response(model_out_text)
         return json.loads(clean_str, strict=True)
+
+
+class SQLListOutputParser(BaseOutputParser):
+    """Parse the SQL list output of an LLM call."""
+
+    metadata = ViewMetadata(
+        label=_("SQL List Output Parser"),
+        name="default_sql_list_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_(
+            "Parse the SQL list output of an LLM call, mostly used for dashboard."
+        ),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("List SQL Output"),
+                "list",
+                dict,
+                is_list=True,
+                description=_("The list output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, is_stream_out: bool = False, **kwargs):
+        """Create a new SQL list output parser."""
+        super().__init__(is_stream_out=is_stream_out, **kwargs)
+
+    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
+        from dbgpt.util.json_utils import find_json_objects
+
+        model_out_text = super().parse_model_nostream_resp(response, sep)
+        json_objects = find_json_objects(model_out_text)
+        json_count = len(json_objects)
+        if json_count < 1:
+            raise ValueError("Unable to obtain valid output.")
+
+        parsed_json_list = json_objects[0]
+        if not isinstance(parsed_json_list, list):
+            if isinstance(parsed_json_list, dict):
+                return [parsed_json_list]
+            else:
+                raise ValueError("Invalid output format.")
+        return parsed_json_list

@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
         values["input_variables"] = sorted(input_variables)
         return values
 
+    def get_placeholders(self) -> List[str]:
+        """Get all placeholders in the prompt template.
+
+        Returns:
+            List[str]: The placeholders.
+        """
+        placeholders = set()
+        for message in self.messages:
+            if isinstance(message, MessagesPlaceholder):
+                placeholders.add(message.variable_name)
+        return sorted(placeholders)
+
 
 @dataclasses.dataclass
 class PromptTemplateIdentifier(ResourceIdentifier):
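[Illustrative aside, not part of the diff] `get_placeholders` reports which message
slots must be supplied before the template can be formatted. A sketch, assuming the
`MessagesPlaceholder` and prompt-template classes from the same module:

    template = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful AI assistant."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    template.get_placeholders()  # -> ["chat_history"]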