Mirror of https://github.com/csunny/DB-GPT.git
feat(core): Support higher-order operators (#1984)
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
@@ -195,6 +195,9 @@ class ModelRequest:
     temperature: Optional[float] = None
     """The temperature of the model inference."""
 
+    top_p: Optional[float] = None
+    """The top p of the model inference."""
+
     max_new_tokens: Optional[int] = None
     """The maximum number of tokens to generate."""
 
@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
         """
         return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
 
+    @staticmethod
+    def parse_user_message(messages: List[ModelMessage]) -> str:
+        """Parse the user message from messages.
+
+        Args:
+            messages (List[ModelMessage]): All the messages in the conversation.
+
+        Returns:
+            str: The user message.
+        """
+        lass_user_message = None
+        for message in messages[::-1]:
+            if message.role == ModelMessageRoleType.HUMAN:
+                lass_user_message = message.content
+                break
+        if not lass_user_message:
+            raise ValueError("No user message")
+        return lass_user_message
+
 
 _SingleRoundMessage = List[BaseMessage]
 _MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
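The new helper walks the conversation from the end and returns the most recent human message. A minimal sketch of its behavior (imports from dbgpt.core, as elsewhere in this commit; the message contents are invented for illustration):

    from dbgpt.core import ModelMessage, ModelMessageRoleType

    messages = [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are a helpful bot."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is AWEL?"),
        ModelMessage(role=ModelMessageRoleType.AI, content="An agentic workflow language."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="Show me an example."),
    ]

    # Iterates messages[::-1], so the latest HUMAN message wins; raises
    # ValueError("No user message") if the conversation contains none.
    assert ModelMessage.parse_user_message(messages) == "Show me an example."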
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
             content=ai_message.content,
             index=ai_message.index,
             round_index=ai_message.round_index,
-            additional_kwargs=ai_message.additional_kwargs.copy()
-            if ai_message.additional_kwargs
-            else {},
+            additional_kwargs=(
+                ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {}
+            ),
         )
         current_round.append(view_message)
     return sum(messages_by_round, [])
@@ -246,10 +246,16 @@ class BaseLLM:
 
     SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
     SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
+    SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
 
-    def __init__(self, llm_client: Optional[LLMClient] = None):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+    ):
         """Create a new LLM operator."""
         self._llm_client = llm_client
+        self._save_model_output = save_model_output
 
     @property
     def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@ class BaseLLM:
         self, current_dag_context: DAGContext, model_output: ModelOutput
     ) -> None:
         """Save the model output to the share data."""
-        await current_dag_context.save_to_share_data(
-            self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
-        )
+        if self._save_model_output:
+            await current_dag_context.save_to_share_data(
+                self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
+            )
 
 
 class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
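With the guard above, the published output stays available to downstream operators whenever save_model_output keeps its default. A hedged sketch of a consumer (the operator class is invented for illustration; current_dag_context and get_from_share_data are standard AWEL operator APIs):

    from dbgpt.core import ModelOutput
    from dbgpt.core.awel import MapOperator

    class ReadModelOutputOperator(MapOperator[str, str]):
        """Illustrative consumer of the output published by BaseLLM."""

        async def map(self, x: str) -> str:
            # The key equals BaseLLM.SHARE_DATA_KEY_MODEL_OUTPUT above.
            out: ModelOutput = await self.current_dag_context.get_from_share_data(
                "share_data_key_model_output"
            )
            return out.text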
@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
     This operator will generate a non-streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a new LLM operator."""
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         MapOperator.__init__(self, **kwargs)
 
     async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
     This operator will generate a streaming response.
     """
 
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
+    def __init__(
+        self,
+        llm_client: Optional[LLMClient] = None,
+        save_model_output: bool = True,
+        **kwargs,
+    ):
         """Create a streaming operator for an LLM.
 
         Args:
             llm_client (LLMClient, optional): The LLM client. Defaults to None.
         """
-        super().__init__(llm_client=llm_client)
+        super().__init__(llm_client=llm_client, save_model_output=save_model_output)
         BaseOperator.__init__(self, **kwargs)
 
     async def streamify(  # type: ignore
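Taken together, these hunks let an intermediate LLM call in a composed ("higher-order") flow opt out of publishing its result, so only the final call owns the share-data slot. A sketch, assuming LLMOperator from dbgpt.model.operators (any BaseLLMOperator subclass would do; task wiring is omitted):

    from dbgpt.core.awel import DAG
    from dbgpt.model.operators import LLMOperator

    with DAG("higher_order_llm_flow"):
        # Intermediate call: its ModelOutput feeds the next task directly,
        # so it skips writing to the DAG share data.
        draft_task = LLMOperator(save_model_output=False)
        # Final call: keeps the default (True) and publishes its output.
        final_task = LLMOperator()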
@@ -4,14 +4,10 @@ from abc import ABC
 from typing import Any, Dict, List, Optional, Union
 
 from dbgpt._private.pydantic import model_validator
-from dbgpt.core import (
-    ModelMessage,
-    ModelMessageRoleType,
-    ModelOutput,
-    StorageConversation,
-)
+from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
 from dbgpt.core.awel import JoinOperator, MapOperator
 from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
     IOField,
     OperatorCategory,
     OperatorType,
@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
     name="common_chat_prompt_template",
     category=ResourceCategory.PROMPT,
     description=_("The operator to build the prompt with static prompt."),
+    tags={"order": TAGS_ORDER_HIGH},
     parameters=[
         Parameter.build_from(
             label=_("System Message"),
@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
 class BasePromptBuilderOperator(BaseConversationOperator, ABC):
     """The base prompt builder operator."""
 
-    def __init__(self, check_storage: bool, **kwargs):
+    def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
         """Create a new prompt builder operator."""
         super().__init__(check_storage=check_storage, **kwargs)
+        self._save_to_storage = save_to_storage
 
     async def format_prompt(
         self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
         messages = prompt.format_messages(**pass_kwargs)
         model_messages = ModelMessage.from_base_messages(messages)
-        # Start new round conversation, and save user message to storage
-        await self.start_new_round_conv(model_messages)
+        if self._save_to_storage:
+            # Start new round conversation, and save user message to storage
+            await self.start_new_round_conv(model_messages)
         return model_messages
 
     async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
         Args:
             messages (List[ModelMessage]): The messages.
         """
-        lass_user_message = None
-        for message in messages[::-1]:
-            if message.role == ModelMessageRoleType.HUMAN:
-                lass_user_message = message.content
-                break
-        if not lass_user_message:
-            raise ValueError("No user message")
+        lass_user_message = ModelMessage.parse_user_message(messages)
         storage_conv: Optional[
             StorageConversation
         ] = await self.get_storage_conversation()
@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
 
     async def after_dag_end(self, event_loop_task_id: int):
         """Execute after the DAG finished."""
+        if not self._save_to_storage:
+            return
         # Save the storage conversation to storage after the whole DAG finished
         storage_conv: Optional[
             StorageConversation
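Because the constructor hunks just below forward **kwargs into BasePromptBuilderOperator, the new flag can be passed straight through subclasses. A sketch, assuming HistoryPromptBuilderOperator is importable from dbgpt.core.operators and that save_to_storage travels via **kwargs as these hunks suggest:

    from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate
    from dbgpt.core.operators import HistoryPromptBuilderOperator

    prompt = ChatPromptTemplate(
        messages=[HumanPromptTemplate.from_template("{user_input}")]
    )
    # Used purely as a formatting step: no start_new_round_conv() call and
    # no conversation write-back in after_dag_end().
    prompt_task = HistoryPromptBuilderOperator(
        prompt=prompt,
        check_storage=False,
        save_to_storage=False,
    )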
@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
         self._prompt = prompt
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
 
     @rearrange_args_by_type
@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
         """Create a new history dynamic prompt builder operator."""
         self._history_key = history_key
         self._str_history = str_history
-        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
 
     @rearrange_args_by_type
@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union
 
 from dbgpt.core import ModelOutput
 from dbgpt.core.awel import MapOperator
-from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
+from dbgpt.core.awel.flow import (
+    TAGS_ORDER_HIGH,
+    IOField,
+    OperatorCategory,
+    OperatorType,
+    ViewMetadata,
+)
 from dbgpt.util.i18n_utils import _
 
 T = TypeVar("T")
@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
         if self.current_dag_context.streaming_call:
             return self.parse_model_stream_resp_ex(input_value, 0)
         else:
-            return self.parse_model_nostream_resp(input_value, "###")
+            return self.parse_model_nostream_resp(input_value, "#####################")
 
 
 def _parse_model_response(response: ResponseTye):
@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
 class SQLOutputParser(BaseOutputParser):
     """Parse the SQL output of an LLM call."""
 
+    metadata = ViewMetadata(
+        label=_("SQL Output Parser"),
+        name="default_sql_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_("Parse the SQL output of an LLM call."),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("Dict SQL Output"),
+                "dict",
+                dict,
+                description=_("The dict output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
     def __init__(self, is_stream_out: bool = False, **kwargs):
         """Create a new SQL output parser."""
         super().__init__(is_stream_out=is_stream_out, **kwargs)
@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
         model_out_text = super().parse_model_nostream_resp(response, sep)
         clean_str = super().parse_prompt_response(model_out_text)
         return json.loads(clean_str, strict=True)
+
+
+class SQLListOutputParser(BaseOutputParser):
+    """Parse the SQL list output of an LLM call."""
+
+    metadata = ViewMetadata(
+        label=_("SQL List Output Parser"),
+        name="default_sql_list_output_parser",
+        category=OperatorCategory.OUTPUT_PARSER,
+        description=_(
+            "Parse the SQL list output of an LLM call, mostly used for dashboard."
+        ),
+        parameters=[],
+        inputs=[
+            IOField.build_from(
+                _("Model Output"),
+                "model_output",
+                ModelOutput,
+                description=_("The model output of upstream."),
+            )
+        ],
+        outputs=[
+            IOField.build_from(
+                _("List SQL Output"),
+                "list",
+                dict,
+                is_list=True,
+                description=_("The list output after parsing."),
+            )
+        ],
+        tags={"order": TAGS_ORDER_HIGH},
+    )
+
+    def __init__(self, is_stream_out: bool = False, **kwargs):
+        """Create a new SQL list output parser."""
+        super().__init__(is_stream_out=is_stream_out, **kwargs)
+
+    def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
+        """Parse the output of an LLM call."""
+        from dbgpt.util.json_utils import find_json_objects
+
+        model_out_text = super().parse_model_nostream_resp(response, sep)
+        json_objects = find_json_objects(model_out_text)
+        json_count = len(json_objects)
+        if json_count < 1:
+            raise ValueError("Unable to obtain valid output.")
+
+        parsed_json_list = json_objects[0]
+        if not isinstance(parsed_json_list, list):
+            if isinstance(parsed_json_list, dict):
+                return [parsed_json_list]
+            else:
+                raise ValueError("Invalid output format.")
+        return parsed_json_list
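For context, a sketch of what the list parser does with free-form model text, assuming find_json_objects returns every JSON value it can locate (the sample text is invented):

    from dbgpt.util.json_utils import find_json_objects

    model_out_text = 'Charts:\n[{"sql": "SELECT 1"}, {"sql": "SELECT 2"}] Done.'
    json_objects = find_json_objects(model_out_text)
    # Expected here: [[{'sql': 'SELECT 1'}, {'sql': 'SELECT 2'}]]
    # The parser keeps json_objects[0]: a list is returned as-is, a single
    # dict is wrapped into [dict], and anything else raises ValueError.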
@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
         values["input_variables"] = sorted(input_variables)
         return values
 
+    def get_placeholders(self) -> List[str]:
+        """Get all placeholders in the prompt template.
+
+        Returns:
+            List[str]: The placeholders.
+        """
+        placeholders = set()
+        for message in self.messages:
+            if isinstance(message, MessagesPlaceholder):
+                placeholders.add(message.variable_name)
+        return sorted(placeholders)
+
 
 @dataclasses.dataclass
 class PromptTemplateIdentifier(ResourceIdentifier):
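Illustrative use of the new helper (ChatPromptTemplate, MessagesPlaceholder, and HumanPromptTemplate are DB-GPT's public prompt API; the template content is invented). Only MessagesPlaceholder entries count as placeholders:

    from dbgpt.core import ChatPromptTemplate, HumanPromptTemplate, MessagesPlaceholder

    prompt = ChatPromptTemplate(
        messages=[
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )
    # "{user_input}" is an input variable, not a placeholder.
    assert prompt.get_placeholders() == ["chat_history"]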
@@ -31,6 +31,7 @@ BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
 BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
 BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
 BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
+# Not implemented yet
 BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
 BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
 BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
@@ -373,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC):
     ) -> Any:
         """Query variables from storage."""
 
+    async def async_get(
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ) -> Any:
+        """Query variables from storage async."""
+        raise NotImplementedError("Current variables provider does not support async.")
+
     @abstractmethod
     def save(self, variables_item: StorageVariables) -> None:
         """Save variables to storage."""
@@ -456,6 +466,24 @@ class VariablesPlaceHolder:
             return None
         raise e
 
+    async def async_parse(
+        self,
+        variables_provider: VariablesProvider,
+        ignore_not_found_error: bool = False,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ):
+        """Parse the variables async."""
+        try:
+            return await variables_provider.async_get(
+                self.full_key,
+                self.default_value,
+                default_identifier_map=default_identifier_map,
+            )
+        except ValueError as e:
+            if ignore_not_found_error:
+                return None
+            raise e
+
     def __repr__(self):
         """Return the representation of the variables place holder."""
         return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
@@ -507,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider):
             variable.value = self.encryption.decrypt(variable.value, variable.salt)
         return self._convert_to_value_type(variable)
 
+    async def async_get(
+        self,
+        full_key: str,
+        default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
+        default_identifier_map: Optional[Dict[str, str]] = None,
+    ) -> Any:
+        """Query variables from storage async."""
+        # Try to get variables from storage
+        value = await blocking_func_to_async_no_executor(
+            self.get,
+            full_key,
+            default_value=None,
+            default_identifier_map=default_identifier_map,
+        )
+        if value is not None:
+            return value
+        key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
+        # Get all builtin variables
+        variables = await self.async_get_variables(
+            key=key.key,
+            scope=key.scope,
+            scope_key=key.scope_key,
+            sys_code=key.sys_code,
+            user_name=key.user_name,
+        )
+        values = [v for v in variables if v.name == key.name]
+        if not values:
+            if default_value == _EMPTY_DEFAULT_VALUE:
+                raise ValueError(f"Variable {full_key} not found")
+            return default_value
+        if len(values) > 1:
+            raise ValueError(f"Multiple variables found for {full_key}")
+
+        variable = values[0]
+        return self._convert_to_value_type(variable)
+
     def save(self, variables_item: StorageVariables) -> None:
         """Save variables to storage."""
         if variables_item.category == "secret":
@@ -576,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider):
         )
         if is_builtin:
             return builtin_variables
-        executor_factory: Optional[
-            DefaultExecutorFactory
-        ] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
+        executor_factory: Optional[DefaultExecutorFactory] = None
+        if self.system_app:
+            executor_factory = DefaultExecutorFactory.get_instance(
+                self.system_app, default_component=None
+            )
         if executor_factory:
             return await blocking_func_to_async(
                 executor_factory.create(),
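The async additions make variable resolution usable from a running event loop. A hedged sketch of the intended call pattern (the provider construction and the full_key format are illustrative; async_get and async_parse are the APIs added above):

    import asyncio

    from dbgpt.core.interface.variables import (
        StorageVariablesProvider,
        VariablesPlaceHolder,
    )

    async def main():
        provider = StorageVariablesProvider()
        placeholder = VariablesPlaceHolder(
            "api_key",                              # param_name (illustrative)
            "${dbgpt.core.secrets:my_key@global}",  # full_key (format illustrative)
        )
        # Delegates to provider.async_get(); returns None instead of raising
        # when ignore_not_found_error=True and the variable is missing.
        value = await placeholder.async_parse(provider, ignore_not_found_error=True)
        print(value)

    asyncio.run(main())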