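"""Higher-order LLM operators for DB-GPT AWEL flows.

This module combines chat history loading, prompt composition and model
invocation (blocking and streaming) into reusable higher-order operators.
"""
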
from typing import List, Literal, Optional, Tuple, Union

from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
    BaseMessage,
    ChatPromptTemplate,
    LLMClient,
    ModelOutput,
    ModelRequest,
    StorageConversation,
)
from dbgpt.core.awel import (
    DAG,
    BaseOperator,
    CommonLLMHttpRequestBody,
    DAGContext,
    DefaultInputContext,
    InputOperator,
    JoinOperator,
    MapOperator,
    SimpleCallDataInputSource,
    TaskOutput,
)
from dbgpt.core.awel.flow import (
    TAGS_ORDER_HIGH,
    IOField,
    OperatorCategory,
    OptionValue,
    Parameter,
    ViewMetadata,
    ui,
)
from dbgpt.core.interface.operators.message_operator import (
    BaseConversationOperator,
    BufferedConversationMapperOperator,
    TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer


class HOContextBody(BaseModel):
    """Higher-order context body."""

    context_key: str = Field(
        "context",
        description=_("The context key can be used as the key for formatting prompt."),
    )
    context: Union[str, List[str]] = Field(
        ...,
        description=_("The context."),
    )
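
# Illustrative example (not part of this module): an upstream operator can emit
# an ``HOContextBody`` so that its content is formatted into the prompt under
# ``context_key``, e.g.:
#
#     HOContextBody(context_key="context", context="The weather is sunny.")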


class BaseHOLLMOperator(
    BaseConversationOperator,
    JoinOperator[ModelRequest],
    LLMOperator,
    StreamingLLMOperator,
):
    """Higher-order model request builder operator."""

    def __init__(
        self,
        prompt_template: ChatPromptTemplate,
        model: Optional[str] = None,
        llm_client: Optional[LLMClient] = None,
        history_merge_mode: Literal["none", "window", "token"] = "window",
        user_message_key: str = "user_input",
        history_key: Optional[str] = None,
        keep_start_rounds: Optional[int] = None,
        keep_end_rounds: Optional[int] = None,
        max_token_limit: int = 2048,
        **kwargs,
    ):
        # Initialize all three operator bases explicitly; the join step combines
        # the request body and extra contexts into a single ModelRequest.
        JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
        LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
        StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)

        # The user must select a history merge mode
        self._history_merge_mode = history_merge_mode
        self._user_message_key = user_message_key
        self._has_history = history_merge_mode != "none"
        self._prompt_template = prompt_template
        self._model = model
        self._history_key = history_key
        self._str_history = False
        self._keep_start_rounds = keep_start_rounds if self._has_history else 0
        self._keep_end_rounds = keep_end_rounds if self._has_history else 0
        self._max_token_limit = max_token_limit
        self._sub_compose_dag = self._build_conversation_composer_dag()

    async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
        conv_serve = ConversationServe.get_instance(self.system_app)
        self._storage = conv_serve.conv_storage
        self._message_storage = conv_serve.message_storage

        # First run the join step to build the ModelRequest, then feed its
        # output to the (streaming or blocking) LLM step as the only input.
        _: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
        dag_ctx.current_task_context.set_task_input(
            DefaultInputContext([dag_ctx.current_task_context])
        )
        if dag_ctx.streaming_call:
            task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
        else:
            task_output = await LLMOperator._do_run(self, dag_ctx)

        return task_output

    async def after_dag_end(self, event_loop_task_id: int):
        # Persist the model output (and its view, if any) after the DAG ends
        model_output: Optional[ModelOutput] = (
            await self.current_dag_context.get_from_share_data(
                LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
            )
        )
        model_output_view: Optional[str] = (
            await self.current_dag_context.get_from_share_data(
                LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
            )
        )
        storage_conv = await self.get_storage_conversation()
        end_current_round: bool = False
        if model_output and storage_conv:
            # Save the model output message to storage
            storage_conv.add_ai_message(model_output.text)
            end_current_round = True
        if model_output_view and storage_conv:
            # Save the model output view to storage
            storage_conv.add_view_message(model_output_view)
            end_current_round = True
        if end_current_round:
            # End the current conversation round and flush it to storage
            storage_conv.end_current_round()

    async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
        dynamic_inputs = []
        for arg in args:
            if isinstance(arg, HOContextBody):
                dynamic_inputs.append(arg)
        # Load and store the chat history; by default InMemoryStorage is used.
        storage_conv, history_messages = await self.blocking_func_to_async(
            self._build_storage, req
        )
        # Save the storage conversation to shared data for the child operators
        await self.current_dag_context.save_to_share_data(
            self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
        )

        user_input = (
            req.messages[-1] if isinstance(req.messages, list) else req.messages
        )
        prompt_dict = {
            self._user_message_key: user_input,
        }
        for dynamic_input in dynamic_inputs:
            if dynamic_input.context_key in prompt_dict:
                raise ValueError(
                    f"Duplicate context key '{dynamic_input.context_key}' in upstream "
                    f"operators."
                )
            prompt_dict[dynamic_input.context_key] = dynamic_input.context

        call_data = {
            "messages": history_messages,
            "prompt_dict": prompt_dict,
        }
        end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
        # Run the sub-DAG, reusing the same DAG context as the parent DAG
        messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
        model_request = ModelRequest.build_request(
            model=req.model,
            messages=messages,
            context=req.context,
            temperature=req.temperature,
            max_new_tokens=req.max_new_tokens,
            span_id=root_tracer.get_current_span_id(),
            echo=False,
        )
        if storage_conv:
            # Start a new round and record the user message
            storage_conv.start_new_round()
            storage_conv.add_user_message(user_input)
        return model_request

    def _build_storage(
        self, req: CommonLLMHttpRequestBody
    ) -> Tuple[StorageConversation, List[BaseMessage]]:
        # Creating a new StorageConversation loads the conversation from
        # storage, which may block, so callers wrap this method in
        # ``blocking_func_to_async``.
        storage_conv: StorageConversation = StorageConversation(
            conv_uid=req.conv_uid,
            chat_mode=req.chat_mode,
            user_name=req.user_name,
            sys_code=req.sys_code,
            conv_storage=self.storage,
            message_storage=self.message_storage,
            param_type="",
            param_value=req.chat_param,
        )
        # Get the history messages from storage
        history_messages: List[BaseMessage] = storage_conv.get_history_message(
            include_system_message=False
        )

        return storage_conv, history_messages

    def _build_conversation_composer_dag(self) -> DAG:
        with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
            input_task = InputOperator(input_source=SimpleCallDataInputSource())
            # History transform task: trim the history by token length or by
            # conversation rounds
            if self._history_merge_mode == "token":
                history_transform_task = TokenBufferedConversationMapperOperator(
                    model=self._model,
                    llm_client=self.llm_client,
                    max_token_limit=self._max_token_limit,
                )
            else:
                history_transform_task = BufferedConversationMapperOperator(
                    keep_start_rounds=self._keep_start_rounds,
                    keep_end_rounds=self._keep_end_rounds,
                )
            if self._history_key:
                history_key = self._history_key
            else:
                placeholders = self._prompt_template.get_placeholders()
                if not placeholders or len(placeholders) != 1:
                    raise ValueError(
                        "The prompt template must have exactly one placeholder if "
                        "history_key is not provided."
                    )
                history_key = placeholders[0]
            history_prompt_build_task = HistoryPromptBuilderOperator(
                prompt=self._prompt_template,
                history_key=history_key,
                check_storage=False,
                save_to_storage=False,
                str_history=self._str_history,
            )
            # Build the composer DAG
            (
                input_task
                >> MapOperator(lambda x: x["messages"])
                >> history_transform_task
                >> history_prompt_build_task
            )
            (
                input_task
                >> MapOperator(lambda x: x["prompt_dict"])
                >> history_prompt_build_task
            )

        return composer_dag
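
# Shape of the conversation composer sub-DAG built above (illustrative):
#
#     input --(messages)----> history transform ---> history prompt builder
#     input --(prompt_dict)------------------------> history prompt builder
#
# The prompt builder joins the trimmed history with the prompt variables to
# produce the final message list for the model request.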


_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
    _("Prompt Template"),
    "prompt_template",
    ChatPromptTemplate,
    description=_("The prompt template for the conversation."),
)

_PARAMETER_MODEL = Parameter.build_from(
    _("Model Name"),
    "model",
    str,
    optional=True,
    default=None,
    description=_("The model name."),
)

_PARAMETER_LLM_CLIENT = Parameter.build_from(
    _("LLM Client"),
    "llm_client",
    LLMClient,
    optional=True,
    default=None,
    description=_(
        "The LLM client, which defines how to connect to the LLM model. If not "
        "provided, the default client deployed by DB-GPT is used."
    ),
)

_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
    _("History Message Merge Mode"),
    "history_merge_mode",
    str,
    optional=True,
    default="none",
    options=[
        OptionValue(label="No History", name="none", value="none"),
        OptionValue(label="Message Window", name="window", value="window"),
        OptionValue(label="Token Length", name="token", value="token"),
    ],
    description=_(
        "The history merge mode, supports 'none', 'window' and 'token'."
        " 'none': no history merge, 'window': merge by conversation window, "
        "'token': merge by token length."
    ),
    ui=ui.UISelect(),
)

_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
    _("User Message Key"),
    "user_message_key",
    str,
    optional=True,
    default="user_input",
    description=_(
        "The key of the user message in your prompt, default is 'user_input'."
    ),
)

_PARAMETER_HISTORY_KEY = Parameter.build_from(
    _("History Key"),
    "history_key",
    str,
    optional=True,
    default=None,
    description=_(
        "The chat history key, used to pass chat history messages to the prompt "
        "template. If not provided, the prompt template is parsed to find the key."
    ),
)

_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
    _("Keep Start Rounds"),
    "keep_start_rounds",
    int,
    optional=True,
    default=None,
    description=_("The start rounds to keep in the chat history."),
)

_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
    _("Keep End Rounds"),
    "keep_end_rounds",
    int,
    optional=True,
    default=None,
    description=_("The end rounds to keep in the chat history."),
)

_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
    _("Max Token Limit"),
    "max_token_limit",
    int,
    optional=True,
    default=2048,
    description=_("The max token limit to keep in the chat history."),
)

_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
    _("Common LLM Request Body"),
    "common_llm_request_body",
    CommonLLMHttpRequestBody,
    _("The common LLM request body."),
)

_INPUTS_EXTRA_CONTEXT = IOField.build_from(
    _("Extra Context"),
    "extra_context",
    HOContextBody,
    _(
        "Extra context for building the prompt (knowledge context, database "
        "schema, etc.). Multiple contexts can be added."
    ),
    dynamic=True,
)

_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
    _("Model Output"),
    "model_output",
    ModelOutput,
    description=_("The model output."),
)

_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
    _("Streaming Model Output"),
    "streaming_model_output",
    ModelOutput,
    is_list=True,
    description=_("The streaming model output."),
)
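
# The ``Parameter`` and ``IOField`` definitions above describe the
# configuration form and the input/output ports of the operators in the AWEL
# flow UI; ``.new()`` below creates a fresh copy for each ``ViewMetadata``.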


class HOLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("LLM Operator"),
        name="higher_order_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
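
# Usage sketch (illustrative; ``my_prompt_template`` and the trigger producing
# a ``CommonLLMHttpRequestBody`` are assumptions, not part of this module):
#
#     with DAG("example_chat_dag") as dag:
#         trigger = ...  # any operator yielding a CommonLLMHttpRequestBody
#         llm_task = HOLLMOperator(prompt_template=my_prompt_template)
#         trigger >> llm_task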


class HOStreamingLLMOperator(BaseHOLLMOperator):
    metadata = ViewMetadata(
        label=_("Streaming LLM Operator"),
        name="higher_order_streaming_llm_operator",
        category=OperatorCategory.LLM,
        description=_(
            "High-level streaming LLM operator, supports multi-round conversation "
            "(conversation window, token length and no multi-round)."
        ),
        parameters=[
            _PARAMETER_PROMPT_TEMPLATE.new(),
            _PARAMETER_MODEL.new(),
            _PARAMETER_LLM_CLIENT.new(),
            _PARAMETER_HISTORY_MERGE_MODE.new(),
            _PARAMETER_USER_MESSAGE_KEY.new(),
            _PARAMETER_HISTORY_KEY.new(),
            _PARAMETER_KEEP_START_ROUNDS.new(),
            _PARAMETER_KEEP_END_ROUNDS.new(),
            _PARAMETER_MAX_TOKEN_LIMIT.new(),
        ],
        inputs=[
            _INPUTS_COMMON_LLM_REQUEST_BODY.new(),
            _INPUTS_EXTRA_CONTEXT.new(),
        ],
        outputs=[
            _OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
        ],
        tags={"order": TAGS_ORDER_HIGH},
    )

    def __init__(self, **kwargs):
        super().__init__(**kwargs)