mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 05:49:22 +00:00
Native data AI application framework based on AWEL+AGENT (#1152)
Co-authored-by: Fangyin Cheng <staneyffer@gmail.com> Co-authored-by: lcx01800250 <lcx01800250@alibaba-inc.com> Co-authored-by: licunxing <864255598@qq.com> Co-authored-by: Aralhi <xiaoping0501@gmail.com> Co-authored-by: xuyuan23 <643854343@qq.com> Co-authored-by: aries_ckt <916701291@qq.com> Co-authored-by: hzh97 <2976151305@qq.com>
This commit is contained in:
@@ -9,9 +9,20 @@ from dbgpt.core.awel import (
|
||||
BaseOperator,
|
||||
BranchFunc,
|
||||
BranchOperator,
|
||||
CommonLLMHttpRequestBody,
|
||||
CommonLLMHttpResponseBody,
|
||||
DAGContext,
|
||||
JoinOperator,
|
||||
MapOperator,
|
||||
StreamifyAbsOperator,
|
||||
TransformStreamAbsOperator,
|
||||
)
|
||||
from dbgpt.core.awel.flow import (
|
||||
IOField,
|
||||
OperatorCategory,
|
||||
OperatorType,
|
||||
Parameter,
|
||||
ViewMetadata,
|
||||
)
|
||||
from dbgpt.core.interface.llm import (
|
||||
LLMClient,
|
||||
@@ -20,6 +31,7 @@ from dbgpt.core.interface.llm import (
|
||||
ModelRequestContext,
|
||||
)
|
||||
from dbgpt.core.interface.message import ModelMessage
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
|
||||
RequestInput = Union[
|
||||
ModelRequest,
|
||||
@@ -31,9 +43,42 @@ RequestInput = Union[
|
||||
]
|
||||
|
||||
|
||||
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest]):
|
||||
"""Build the model request from the input value."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Build Model Request",
|
||||
name="request_builder_operator",
|
||||
category=OperatorCategory.COMMON,
|
||||
description="Build the model request from the http request body.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Default Model Name",
|
||||
"model",
|
||||
str,
|
||||
optional=True,
|
||||
default=None,
|
||||
description="The model name of the model request.",
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Request Body",
|
||||
"input_value",
|
||||
CommonLLMHttpRequestBody,
|
||||
description="The input value of the operator.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Model Request",
|
||||
"output_value",
|
||||
ModelRequest,
|
||||
description="The output value of the operator.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, model: Optional[str] = None, **kwargs):
|
||||
"""Create a new request builder operator."""
|
||||
self._model = model
|
||||
@@ -90,6 +135,53 @@ class RequestBuilderOperator(MapOperator[RequestInput, ModelRequest], ABC):
|
||||
return ModelRequest(**req_dict)
|
||||
|
||||
|
||||
class MergedRequestBuilderOperator(JoinOperator[ModelRequest]):
|
||||
"""Build the model request from the input value."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Merge Model Request Messages",
|
||||
name="merged_request_builder_operator",
|
||||
category=OperatorCategory.COMMON,
|
||||
description="Merge the model request from the input value.",
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Model Request",
|
||||
"model_request",
|
||||
ModelRequest,
|
||||
description="The model request of upstream.",
|
||||
),
|
||||
IOField.build_from(
|
||||
"Model messages",
|
||||
"messages",
|
||||
ModelMessage,
|
||||
description="The model messages of upstream.",
|
||||
is_list=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Model Request",
|
||||
"output_value",
|
||||
ModelRequest,
|
||||
description="The output value of the operator.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Create a new request builder operator."""
|
||||
super().__init__(combine_function=self.merge_func, **kwargs)
|
||||
|
||||
@rearrange_args_by_type
|
||||
def merge_func(
|
||||
self, model_request: ModelRequest, messages: List[ModelMessage]
|
||||
) -> ModelRequest:
|
||||
"""Merge the model request with the messages."""
|
||||
model_request.messages = messages
|
||||
return model_request
|
||||
|
||||
|
||||
class BaseLLM:
|
||||
"""The abstract operator for a LLM."""
|
||||
|
||||
@@ -189,6 +281,54 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
the stream flag of the request.
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="LLM Branch Operator",
|
||||
name="llm_branch_operator",
|
||||
category=OperatorCategory.LLM,
|
||||
operator_type=OperatorType.BRANCH,
|
||||
description="Branch the workflow based on the stream flag of the request.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Streaming Task Name",
|
||||
"stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="streaming_llm_task",
|
||||
description="The name of the streaming task.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
"Non-Streaming Task Name",
|
||||
"no_stream_task_name",
|
||||
str,
|
||||
optional=True,
|
||||
default="llm_task",
|
||||
description="The name of the non-streaming task.",
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Model Request",
|
||||
"input_value",
|
||||
ModelRequest,
|
||||
description="The input value of the operator.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Streaming Model Request",
|
||||
"streaming_request",
|
||||
ModelRequest,
|
||||
description="The streaming request, to streaming Operator.",
|
||||
),
|
||||
IOField.build_from(
|
||||
"Non-Streaming Model Request",
|
||||
"no_streaming_request",
|
||||
ModelRequest,
|
||||
description="The non-streaming request, to non-streaming Operator.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, stream_task_name: str, no_stream_task_name: str, **kwargs):
|
||||
"""Create a new LLM branch operator.
|
||||
|
||||
@@ -226,3 +366,94 @@ class LLMBranchOperator(BranchOperator[ModelRequest, ModelRequest]):
|
||||
check_stream_true: self._stream_task_name,
|
||||
lambda x: not x.stream: self._no_stream_task_name,
|
||||
}
|
||||
|
||||
|
||||
class ModelOutput2CommonResponseOperator(
|
||||
MapOperator[ModelOutput, CommonLLMHttpResponseBody]
|
||||
):
|
||||
"""Map the model output to the common response body."""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Map Model Output to Common Response Body",
|
||||
name="model_output_2_common_response_body_operator",
|
||||
category=OperatorCategory.COMMON,
|
||||
description="Map the model output to the common response body.",
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Model Output",
|
||||
"input_value",
|
||||
ModelOutput,
|
||||
description="The input value of the operator.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Common Response Body",
|
||||
"output_value",
|
||||
CommonLLMHttpResponseBody,
|
||||
description="The output value of the operator.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def __int__(self, **kwargs):
|
||||
"""Create a new operator."""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> CommonLLMHttpResponseBody:
|
||||
"""Map the model output to the common response body."""
|
||||
metrics = input_value.metrics.to_dict() if input_value.metrics else None
|
||||
return CommonLLMHttpResponseBody(
|
||||
text=input_value.text,
|
||||
error_code=input_value.error_code,
|
||||
metrics=metrics,
|
||||
)
|
||||
|
||||
|
||||
class CommonStreamingOutputOperator(TransformStreamAbsOperator[ModelOutput, str]):
|
||||
"""The Common Streaming Output Operator.
|
||||
|
||||
Transform model output to the string output to show in DB-GPT chat flow page.
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Common Streaming Output Operator",
|
||||
name="common_streaming_output_operator",
|
||||
operator_type=OperatorType.TRANSFORM_STREAM,
|
||||
category=OperatorCategory.OUTPUT_PARSER,
|
||||
description="The common streaming LLM operator, for chat flow.",
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Upstream Model Output",
|
||||
"output_iter",
|
||||
ModelOutput,
|
||||
is_list=True,
|
||||
description="The model output of upstream.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Model Output",
|
||||
"model_output",
|
||||
str,
|
||||
is_list=True,
|
||||
description="The model output after transform to common stream format",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def transform_stream(self, output_iter: AsyncIterator[ModelOutput]):
|
||||
"""Transform upstream output iter to string foramt."""
|
||||
async for model_output in output_iter:
|
||||
if model_output.error_code != 0:
|
||||
error_msg = (
|
||||
f"[ERROR](error_code: {model_output.error_code}): "
|
||||
f"{model_output.text}"
|
||||
)
|
||||
yield f"data:{error_msg}"
|
||||
return
|
||||
decoded_unicode = model_output.text.replace("\ufffd", "")
|
||||
msg = decoded_unicode.replace("\n", "\\n")
|
||||
yield f"data:{msg}\n\n"
|
||||
|
@@ -1,18 +1,22 @@
|
||||
"""The message operator."""
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union, cast
|
||||
|
||||
from dbgpt.core import (
|
||||
InMemoryStorage,
|
||||
LLMClient,
|
||||
MessageStorageItem,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
ModelRequest,
|
||||
ModelRequestContext,
|
||||
StorageConversation,
|
||||
StorageInterface,
|
||||
)
|
||||
from dbgpt.core.awel import BaseOperator, MapOperator
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, Parameter, ViewMetadata
|
||||
from dbgpt.core.interface.message import (
|
||||
BaseMessage,
|
||||
_messages_to_str,
|
||||
@@ -20,6 +24,8 @@ from dbgpt.core.interface.message import (
|
||||
_split_messages_by_round,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseConversationOperator(BaseOperator, ABC):
|
||||
"""Base class for conversation operators."""
|
||||
@@ -96,7 +102,7 @@ class BaseConversationOperator(BaseOperator, ABC):
|
||||
raise ValueError(f"Message role {message.role} is not supported")
|
||||
|
||||
|
||||
ChatHistoryLoadType = Union[ModelRequestContext, Dict[str, Any]]
|
||||
ChatHistoryLoadType = Union[ModelRequest, ModelRequestContext, Dict[str, Any]]
|
||||
|
||||
|
||||
class PreChatHistoryLoadOperator(
|
||||
@@ -111,6 +117,50 @@ class PreChatHistoryLoadOperator(
|
||||
This operator just load the conversation and messages from storage.
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Chat History Load Operator",
|
||||
name="chat_history_load_operator",
|
||||
category=OperatorCategory.CONVERSION,
|
||||
description="The operator to load chat history from storage.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
label="Conversation Storage",
|
||||
name="storage",
|
||||
type=StorageInterface,
|
||||
optional=True,
|
||||
default=None,
|
||||
description="The conversation storage, store the conversation items("
|
||||
"Not include message items). If None, we will use InMemoryStorage.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
label="Message Storage",
|
||||
name="message_storage",
|
||||
type=StorageInterface,
|
||||
optional=True,
|
||||
default=None,
|
||||
description="The message storage, store the messages of one "
|
||||
"conversation. If None, we will use InMemoryStorage.",
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
label="Model Request",
|
||||
name="input_value",
|
||||
type=ModelRequest,
|
||||
description="The model request.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
label="Stored Messages",
|
||||
name="output_value",
|
||||
type=BaseMessage,
|
||||
description="The messages stored in the storage.",
|
||||
is_list=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
|
||||
@@ -119,6 +169,17 @@ class PreChatHistoryLoadOperator(
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new PreChatHistoryLoadOperator."""
|
||||
if not storage:
|
||||
logger.info(
|
||||
"Storage is not set, use the InMemoryStorage as the conversation "
|
||||
"storage."
|
||||
)
|
||||
storage = InMemoryStorage()
|
||||
if not message_storage:
|
||||
logger.info(
|
||||
"Message storage is not set, use the InMemoryStorage as the message "
|
||||
)
|
||||
message_storage = InMemoryStorage()
|
||||
super().__init__(storage=storage, message_storage=message_storage)
|
||||
MapOperator.__init__(self, **kwargs)
|
||||
self._include_system_message = include_system_message
|
||||
@@ -136,6 +197,11 @@ class PreChatHistoryLoadOperator(
|
||||
raise ValueError("Model request context can't be None")
|
||||
if isinstance(input_value, dict):
|
||||
input_value = ModelRequestContext(**input_value)
|
||||
elif isinstance(input_value, ModelRequest):
|
||||
if not input_value.context:
|
||||
raise ValueError("Model request context can't be None")
|
||||
input_value = input_value.context
|
||||
input_value = cast(ModelRequestContext, input_value)
|
||||
if not input_value.conv_uid:
|
||||
input_value.conv_uid = str(uuid.uuid4())
|
||||
if not input_value.extra:
|
||||
|
@@ -2,27 +2,94 @@
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from dbgpt._private.pydantic import root_validator
|
||||
from dbgpt.core import (
|
||||
ChatPromptTemplate,
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
ModelOutput,
|
||||
StorageConversation,
|
||||
)
|
||||
from dbgpt.core.awel import JoinOperator, MapOperator
|
||||
from dbgpt.core.awel.flow import (
|
||||
IOField,
|
||||
OperatorCategory,
|
||||
OperatorType,
|
||||
Parameter,
|
||||
ResourceCategory,
|
||||
ViewMetadata,
|
||||
register_resource,
|
||||
)
|
||||
from dbgpt.core.interface.message import BaseMessage
|
||||
from dbgpt.core.interface.operators.llm_operator import BaseLLM
|
||||
from dbgpt.core.interface.operators.message_operator import BaseConversationOperator
|
||||
from dbgpt.core.interface.prompt import (
|
||||
BaseChatPromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
HumanPromptTemplate,
|
||||
MessagesPlaceholder,
|
||||
MessageType,
|
||||
PromptTemplate,
|
||||
SystemPromptTemplate,
|
||||
)
|
||||
from dbgpt.util.function_utils import rearrange_args_by_type
|
||||
|
||||
|
||||
@register_resource(
|
||||
label="Common Chat Prompt Template",
|
||||
name="common_chat_prompt_template",
|
||||
category=ResourceCategory.PROMPT,
|
||||
description="The operator to build the prompt with static prompt.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
label="System Message",
|
||||
name="system_message",
|
||||
type=str,
|
||||
optional=True,
|
||||
default="You are a helpful AI Assistant.",
|
||||
description="The system message.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
label="Message placeholder",
|
||||
name="message_placeholder",
|
||||
type=str,
|
||||
optional=True,
|
||||
default="chat_history",
|
||||
description="The chat history message placeholder.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
label="Human Message",
|
||||
name="human_message",
|
||||
type=str,
|
||||
optional=True,
|
||||
default="{user_input}",
|
||||
placeholder="{user_input}",
|
||||
description="The human message.",
|
||||
),
|
||||
],
|
||||
)
|
||||
class CommonChatPromptTemplate(ChatPromptTemplate):
|
||||
"""The common chat prompt template."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre fill the messages."""
|
||||
if "system_message" not in values:
|
||||
raise ValueError("No system message")
|
||||
if "human_message" not in values:
|
||||
raise ValueError("No human message")
|
||||
if "message_placeholder" not in values:
|
||||
raise ValueError("No message placeholder")
|
||||
system_message = values.pop("system_message")
|
||||
human_message = values.pop("human_message")
|
||||
message_placeholder = values.pop("message_placeholder")
|
||||
values["messages"] = [
|
||||
SystemPromptTemplate.from_template(system_message),
|
||||
MessagesPlaceholder(variable_name=message_placeholder),
|
||||
HumanPromptTemplate.from_template(human_message),
|
||||
]
|
||||
return cls.base_pre_fill(values)
|
||||
|
||||
|
||||
class BasePromptBuilderOperator(BaseConversationOperator, ABC):
|
||||
"""The base prompt builder operator."""
|
||||
|
||||
@@ -183,6 +250,38 @@ class PromptBuilderOperator(
|
||||
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Prompt Builder Operator",
|
||||
name="prompt_builder_operator",
|
||||
description="Build messages from prompt template.",
|
||||
category=OperatorCategory.COMMON,
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Chat Prompt Template",
|
||||
"prompt",
|
||||
ChatPromptTemplate,
|
||||
description="The chat prompt template.",
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Prompt Input Dict",
|
||||
"prompt_input_dict",
|
||||
dict,
|
||||
description="The prompt dict.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Formatted Messages",
|
||||
"formatted_messages",
|
||||
ModelMessage,
|
||||
is_list=True,
|
||||
description="The formatted messages.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, prompt: PromptTemplateType, **kwargs):
|
||||
"""Create a new prompt builder operator."""
|
||||
if isinstance(prompt, str):
|
||||
@@ -237,6 +336,62 @@ class HistoryPromptBuilderOperator(
|
||||
The prompt will pass to this operator.
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="History Prompt Builder Operator",
|
||||
name="history_prompt_builder_operator",
|
||||
description="Build messages from prompt template and chat history.",
|
||||
operator_type=OperatorType.JOIN,
|
||||
category=OperatorCategory.CONVERSION,
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Chat Prompt Template",
|
||||
"prompt",
|
||||
ChatPromptTemplate,
|
||||
description="The chat prompt template.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
"History Key",
|
||||
"history_key",
|
||||
str,
|
||||
optional=True,
|
||||
default="chat_history",
|
||||
description="The key of history in prompt dict.",
|
||||
),
|
||||
Parameter.build_from(
|
||||
"String History",
|
||||
"str_history",
|
||||
bool,
|
||||
optional=True,
|
||||
default=False,
|
||||
description="Whether to convert the history to string.",
|
||||
),
|
||||
],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"History",
|
||||
"history",
|
||||
BaseMessage,
|
||||
is_list=True,
|
||||
description="The history.",
|
||||
),
|
||||
IOField.build_from(
|
||||
"Prompt Input Dict",
|
||||
"prompt_input_dict",
|
||||
dict,
|
||||
description="The prompt dict.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Formatted Messages",
|
||||
"formatted_messages",
|
||||
ModelMessage,
|
||||
is_list=True,
|
||||
description="The formatted messages.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt: ChatPromptTemplate,
|
||||
|
@@ -13,6 +13,7 @@ from typing import Any, TypeVar, Union
|
||||
|
||||
from dbgpt.core import ModelOutput
|
||||
from dbgpt.core.awel import MapOperator
|
||||
from dbgpt.core.awel.flow import IOField, OperatorCategory, OperatorType, ViewMetadata
|
||||
|
||||
T = TypeVar("T")
|
||||
ResponseTye = Union[str, bytes, ModelOutput]
|
||||
@@ -26,6 +27,33 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
Output parsers help structure language model responses.
|
||||
"""
|
||||
|
||||
metadata = ViewMetadata(
|
||||
label="Base Output Operator",
|
||||
name="base_output_operator",
|
||||
operator_type=OperatorType.TRANSFORM_STREAM,
|
||||
category=OperatorCategory.OUTPUT_PARSER,
|
||||
description="The base LLM out parse.",
|
||||
parameters=[],
|
||||
inputs=[
|
||||
IOField.build_from(
|
||||
"Model Output",
|
||||
"model_output",
|
||||
ModelOutput,
|
||||
is_list=True,
|
||||
description="The model output of upstream.",
|
||||
)
|
||||
],
|
||||
outputs=[
|
||||
IOField.build_from(
|
||||
"Model Output",
|
||||
"model_output",
|
||||
str,
|
||||
is_list=True,
|
||||
description="The model output after transform to openai stream format",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
def __init__(self, is_stream_out: bool = True, **kwargs):
|
||||
"""Create a new output parser."""
|
||||
super().__init__(**kwargs)
|
||||
|
@@ -240,7 +240,7 @@ class ChatPromptTemplate(BasePromptTemplate):
|
||||
return result_messages
|
||||
|
||||
@root_validator(pre=True)
|
||||
def pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def base_pre_fill(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Pre-fill the messages."""
|
||||
input_variables = values.get("input_variables", {})
|
||||
messages = values.get("messages", [])
|
||||
|
@@ -3,6 +3,7 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, cast
|
||||
|
||||
from dbgpt.core.awel.flow import Parameter, ResourceCategory, register_resource
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
@@ -384,6 +385,23 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
|
||||
)
|
||||
|
||||
|
||||
@register_resource(
|
||||
label="Memory Storage",
|
||||
name="in_memory_storage",
|
||||
category=ResourceCategory.STORAGE,
|
||||
description="Save your data in memory.",
|
||||
parameters=[
|
||||
Parameter.build_from(
|
||||
"Serializer",
|
||||
"serializer",
|
||||
Serializer,
|
||||
optional=True,
|
||||
default=None,
|
||||
description="The serializer for serializing the data. If not set, the "
|
||||
"default JSON serializer will be used.",
|
||||
)
|
||||
],
|
||||
)
|
||||
@PublicAPI(stability="alpha")
|
||||
class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""The in-memory storage for storing and loading data."""
|
||||
|
Reference in New Issue
Block a user