feat(core): Support higher-order operators (#1984)

Author: Fangyin Cheng
Co-authored-by: 谨欣 <echo.cmy@antgroup.com>
Date: 2024-09-09 10:15:37 +08:00
Committed by: GitHub
Parent: f6d5fc4595
Commit: 65c875db20
62 changed files with 6281 additions and 386 deletions

View File

@@ -195,6 +195,9 @@ class ModelRequest:
temperature: Optional[float] = None
"""The temperature of the model inference."""
top_p: Optional[float] = None
"""The top p of the model inference."""
max_new_tokens: Optional[int] = None
"""The maximum number of tokens to generate."""

View File

@@ -317,6 +317,25 @@ class ModelMessage(BaseModel):
"""
return _messages_to_str(messages, human_prefix, ai_prefix, system_prefix)
@staticmethod
def parse_user_message(messages: List[ModelMessage]) -> str:
"""Parse user message from messages.
Args:
messages (List[ModelMessage]): The all messages in the conversation.
Returns:
str: The user message
"""
lass_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
lass_user_message = message.content
break
if not lass_user_message:
raise ValueError("No user message")
return lass_user_message
_SingleRoundMessage = List[BaseMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[BaseMessage]]
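
Example (not part of the diff): usage of the new helper; the message contents are invented.

from dbgpt.core import ModelMessage, ModelMessageRoleType

messages = [
    ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are helpful."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="What is AWEL?"),
    ModelMessage(role=ModelMessageRoleType.AI, content="A workflow language."),
    ModelMessage(role=ModelMessageRoleType.HUMAN, content="Show an example."),
]

# Walks the conversation in reverse and returns the content of the
# last HUMAN message; raises ValueError if no user message exists.
assert ModelMessage.parse_user_message(messages) == "Show an example."
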
@@ -1244,9 +1263,11 @@ def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
additional_kwargs=ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {},
additional_kwargs=(
ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {}
),
)
current_round.append(view_message)
return sum(messages_by_round, [])

View File

@@ -246,10 +246,16 @@ class BaseLLM:
SHARE_DATA_KEY_MODEL_NAME = "share_data_key_model_name"
SHARE_DATA_KEY_MODEL_OUTPUT = "share_data_key_model_output"
SHARE_DATA_KEY_MODEL_OUTPUT_VIEW = "share_data_key_model_output_view"
def __init__(self, llm_client: Optional[LLMClient] = None):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
):
"""Create a new LLM operator."""
self._llm_client = llm_client
self._save_model_output = save_model_output
@property
def llm_client(self) -> LLMClient:
@@ -262,9 +268,10 @@ class BaseLLM:
self, current_dag_context: DAGContext, model_output: ModelOutput
) -> None:
"""Save the model output to the share data."""
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
if self._save_model_output:
await current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_MODEL_OUTPUT, model_output
)
class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
@@ -276,9 +283,14 @@ class BaseLLMOperator(BaseLLM, MapOperator[ModelRequest, ModelOutput], ABC):
This operator will generate a non-streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
"""Create a new LLM operator."""
super().__init__(llm_client=llm_client)
super().__init__(llm_client=llm_client, save_model_output=save_model_output)
MapOperator.__init__(self, **kwargs)
async def map(self, request: ModelRequest) -> ModelOutput:
@@ -309,13 +321,18 @@ class BaseStreamingLLMOperator(
This operator will generate a streaming response.
"""
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
def __init__(
self,
llm_client: Optional[LLMClient] = None,
save_model_output: bool = True,
**kwargs,
):
"""Create a streaming operator for a LLM.
Args:
llm_client (LLMClient, optional): The LLM client. Defaults to None.
"""
super().__init__(llm_client=llm_client)
super().__init__(llm_client=llm_client, save_model_output=save_model_output)
BaseOperator.__init__(self, **kwargs)
async def streamify( # type: ignore
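
Example (not part of the diff): with the new flag, a composite flow that chains several LLM calls can keep intermediate outputs out of the DAG share data, so only the final call's output is published. LLMOperator is a concrete subclass; its import path here is assumed.

from dbgpt.model.operators import LLMOperator  # import path assumed

# Intermediate call: its output stays out of the DAG share data.
draft_task = LLMOperator(save_model_output=False, task_name="draft_llm")

# Final call keeps the default (True) and publishes its output.
final_task = LLMOperator(task_name="final_llm")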

View File

@@ -4,14 +4,10 @@ from abc import ABC
from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import model_validator
from dbgpt.core import (
ModelMessage,
ModelMessageRoleType,
ModelOutput,
StorageConversation,
)
from dbgpt.core import ModelMessage, ModelOutput, StorageConversation
from dbgpt.core.awel import JoinOperator, MapOperator
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
@@ -42,6 +38,7 @@ from dbgpt.util.i18n_utils import _
name="common_chat_prompt_template",
category=ResourceCategory.PROMPT,
description=_("The operator to build the prompt with static prompt."),
tags={"order": TAGS_ORDER_HIGH},
parameters=[
Parameter.build_from(
label=_("System Message"),
@@ -101,9 +98,10 @@ class CommonChatPromptTemplate(ChatPromptTemplate):
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
"""The base prompt builder operator."""
def __init__(self, check_storage: bool, **kwargs):
def __init__(self, check_storage: bool, save_to_storage: bool = True, **kwargs):
"""Create a new prompt builder operator."""
super().__init__(check_storage=check_storage, **kwargs)
self._save_to_storage = save_to_storage
async def format_prompt(
self, prompt: ChatPromptTemplate, prompt_dict: Dict[str, Any]
@@ -122,8 +120,9 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
pass_kwargs = {k: v for k, v in kwargs.items() if k in prompt.input_variables}
messages = prompt.format_messages(**pass_kwargs)
model_messages = ModelMessage.from_base_messages(messages)
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(model_messages)
if self._save_to_storage:
# Start new round conversation, and save user message to storage
await self.start_new_round_conv(model_messages)
return model_messages
async def start_new_round_conv(self, messages: List[ModelMessage]) -> None:
@@ -132,13 +131,7 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
Args:
messages (List[ModelMessage]): The messages.
"""
last_user_message = None
for message in messages[::-1]:
if message.role == ModelMessageRoleType.HUMAN:
last_user_message = message.content
break
if not last_user_message:
raise ValueError("No user message")
last_user_message = ModelMessage.parse_user_message(messages)
storage_conv: Optional[
StorageConversation
] = await self.get_storage_conversation()
@@ -150,6 +143,8 @@ class BasePromptBuilderOperator(BaseConversationOperator, ABC):
async def after_dag_end(self, event_loop_task_id: int):
"""Execute after the DAG finished."""
if not self._save_to_storage:
return
# Save the storage conversation to storage after the whole DAG finished
storage_conv: Optional[
StorageConversation
@@ -422,7 +417,7 @@ class HistoryPromptBuilderOperator(
self._prompt = prompt
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
@@ -455,7 +450,7 @@ class HistoryDynamicPromptBuilderOperator(
"""Create a new history dynamic prompt builder operator."""
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
BasePromptBuilderOperator.__init__(self, check_storage=check_storage, **kwargs)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
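
Example (not part of the diff): the same opt-out pattern for prompt building. An intermediate builder in a larger flow can format messages without starting a new conversation round or writing to storage when the DAG ends. Import paths are assumed; the template contents are invented.

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
    SystemPromptTemplate,
)
from dbgpt.core.operators import HistoryPromptBuilderOperator  # path assumed

prompt = ChatPromptTemplate(
    messages=[
        SystemPromptTemplate.from_template("You are a data analyst."),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)

# save_to_storage is threaded through **kwargs to BasePromptBuilderOperator.
builder = HistoryPromptBuilderOperator(
    prompt=prompt,
    history_key="chat_history",
    check_storage=False,
    save_to_storage=False,
)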

View File

@@ -13,7 +13,13 @@ from typing import Any, TypeVar, Union
from dbgpt.core import ModelOutput
from dbgpt.core.awel import MapOperator
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OperatorType,
ViewMetadata,
)
from dbgpt.util.i18n_utils import _
T = TypeVar("T")
@@ -271,7 +277,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
if self.current_dag_context.streaming_call:
return self.parse_model_stream_resp_ex(input_value, 0)
else:
return self.parse_model_nostream_resp(input_value, "###")
return self.parse_model_nostream_resp(input_value, "#####################")
def _parse_model_response(response: ResponseTye):
@@ -293,6 +299,31 @@ def _parse_model_response(response: ResponseTye):
class SQLOutputParser(BaseOutputParser):
"""Parse the SQL output of an LLM call."""
metadata = ViewMetadata(
label=_("SQL Output Parser"),
name="default_sql_output_parser",
category=OperatorCategory.OUTPUT_PARSER,
description=_("Parse the SQL output of an LLM call."),
parameters=[],
inputs=[
IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output of upstream."),
)
],
outputs=[
IOField.build_from(
_("Dict SQL Output"),
"dict",
dict,
description=_("The dict output after parsing."),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)
@@ -302,3 +333,57 @@ class SQLOutputParser(BaseOutputParser):
model_out_text = super().parse_model_nostream_resp(response, sep)
clean_str = super().parse_prompt_response(model_out_text)
return json.loads(clean_str, strict=True)
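
Example (not part of the diff): the parser ends with json.loads on the cleaned model text, so a response carrying one JSON object comes back as a dict. The key names below are illustrative, not a fixed schema.

import json

# What the parser effectively does once the wrapper text is stripped:
model_out_text = '{"thoughts": "count users per day", "sql": "SELECT 1"}'
parsed = json.loads(model_out_text, strict=True)
assert parsed["sql"] == "SELECT 1"
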
class SQLListOutputParser(BaseOutputParser):
"""Parse the SQL list output of an LLM call."""
metadata = ViewMetadata(
label=_("SQL List Output Parser"),
name="default_sql_list_output_parser",
category=OperatorCategory.OUTPUT_PARSER,
description=_(
"Parse the SQL list output of an LLM call, mostly used for dashboard."
),
parameters=[],
inputs=[
IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output of upstream."),
)
],
outputs=[
IOField.build_from(
_("List SQL Output"),
"list",
dict,
is_list=True,
description=_("The list output after parsing."),
)
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, is_stream_out: bool = False, **kwargs):
"""Create a new SQL list output parser."""
super().__init__(is_stream_out=is_stream_out, **kwargs)
def parse_model_nostream_resp(self, response: ResponseTye, sep: str):
"""Parse the output of an LLM call."""
from dbgpt.util.json_utils import find_json_objects
model_out_text = super().parse_model_nostream_resp(response, sep)
json_objects = find_json_objects(model_out_text)
json_count = len(json_objects)
if json_count < 1:
raise ValueError("Unable to obtain valid output.")
parsed_json_list = json_objects[0]
if not isinstance(parsed_json_list, list):
if isinstance(parsed_json_list, dict):
return [parsed_json_list]
else:
raise ValueError("Invalid output format.")
return parsed_json_list
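
Example (not part of the diff): the list parser scans the raw text for JSON fragments via find_json_objects, then normalizes the first one so callers always receive a list. The branch above, restated in isolation:

from typing import Any, List

def normalize(parsed: Any) -> List[dict]:
    """Mirror of the branch above: always return a list of objects."""
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        return [parsed]
    raise ValueError("Invalid output format.")

assert normalize({"sql": "SELECT 1"}) == [{"sql": "SELECT 1"}]
assert normalize([{"sql": "SELECT 1"}]) == [{"sql": "SELECT 1"}]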

View File

@@ -254,6 +254,18 @@ class ChatPromptTemplate(BasePromptTemplate):
values["input_variables"] = sorted(input_variables)
return values
def get_placeholders(self) -> List[str]:
"""Get all placeholders in the prompt template.
Returns:
List[str]: The placeholders.
"""
placeholders = set()
for message in self.messages:
if isinstance(message, MessagesPlaceholder):
placeholders.add(message.variable_name)
return sorted(placeholders)
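
Example (not part of the diff): only MessagesPlaceholder entries are collected; ordinary template variables such as {user_input} are not placeholders in this sense. Import paths are assumed to match the rest of dbgpt.core.

from dbgpt.core import (
    ChatPromptTemplate,
    HumanPromptTemplate,
    MessagesPlaceholder,
)

prompt = ChatPromptTemplate(
    messages=[
        MessagesPlaceholder(variable_name="examples"),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanPromptTemplate.from_template("{user_input}"),
    ]
)

# Names come back sorted; "{user_input}" is a plain input variable.
assert prompt.get_placeholders() == ["chat_history", "examples"]
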
@dataclasses.dataclass
class PromptTemplateIdentifier(ResourceIdentifier):

View File

@@ -31,6 +31,7 @@ BUILTIN_VARIABLES_CORE_VARIABLES = "dbgpt.core.variables"
BUILTIN_VARIABLES_CORE_SECRETS = "dbgpt.core.secrets"
BUILTIN_VARIABLES_CORE_LLMS = "dbgpt.core.model.llms"
BUILTIN_VARIABLES_CORE_EMBEDDINGS = "dbgpt.core.model.embeddings"
# Not implemented yet
BUILTIN_VARIABLES_CORE_RERANKERS = "dbgpt.core.model.rerankers"
BUILTIN_VARIABLES_CORE_DATASOURCES = "dbgpt.core.datasources"
BUILTIN_VARIABLES_CORE_AGENTS = "dbgpt.core.agent.agents"
@@ -373,6 +374,15 @@ class VariablesProvider(BaseComponent, ABC):
) -> Any:
"""Query variables from storage."""
async def async_get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage async."""
raise NotImplementedError("Current variables provider does not support async.")
@abstractmethod
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
@@ -456,6 +466,24 @@ class VariablesPlaceHolder:
return None
raise e
async def async_parse(
self,
variables_provider: VariablesProvider,
ignore_not_found_error: bool = False,
default_identifier_map: Optional[Dict[str, str]] = None,
):
"""Parse the variables async."""
try:
return await variables_provider.async_get(
self.full_key,
self.default_value,
default_identifier_map=default_identifier_map,
)
except ValueError as e:
if ignore_not_found_error:
return None
raise e
def __repr__(self):
"""Return the representation of the variables place holder."""
return f"<VariablesPlaceHolder " f"{self.param_name} {self.full_key}>"
@@ -507,6 +535,42 @@ class StorageVariablesProvider(VariablesProvider):
variable.value = self.encryption.decrypt(variable.value, variable.salt)
return self._convert_to_value_type(variable)
async def async_get(
self,
full_key: str,
default_value: Optional[str] = _EMPTY_DEFAULT_VALUE,
default_identifier_map: Optional[Dict[str, str]] = None,
) -> Any:
"""Query variables from storage async."""
# Try to get variables from storage
value = await blocking_func_to_async_no_executor(
self.get,
full_key,
default_value=None,
default_identifier_map=default_identifier_map,
)
if value is not None:
return value
key = VariablesIdentifier.from_str_identifier(full_key, default_identifier_map)
# Get all builtin variables
variables = await self.async_get_variables(
key=key.key,
scope=key.scope,
scope_key=key.scope_key,
sys_code=key.sys_code,
user_name=key.user_name,
)
values = [v for v in variables if v.name == key.name]
if not values:
if default_value == _EMPTY_DEFAULT_VALUE:
raise ValueError(f"Variable {full_key} not found")
return default_value
if len(values) > 1:
raise ValueError(f"Multiple variables found for {full_key}")
variable = values[0]
return self._convert_to_value_type(variable)
def save(self, variables_item: StorageVariables) -> None:
"""Save variables to storage."""
if variables_item.category == "secret":
@@ -576,9 +640,11 @@ class StorageVariablesProvider(VariablesProvider):
)
if is_builtin:
return builtin_variables
executor_factory: Optional[
DefaultExecutorFactory
] = DefaultExecutorFactory.get_instance(self.system_app, default_component=None)
executor_factory: Optional[DefaultExecutorFactory] = None
if self.system_app:
executor_factory = DefaultExecutorFactory.get_instance(
self.system_app, default_component=None
)
if executor_factory:
return await blocking_func_to_async(
executor_factory.create(),