# DB-GPT/dbgpt/app/operators/llm.py
from typing import List, Literal, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
BaseMessage,
ChatPromptTemplate,
LLMClient,
ModelOutput,
ModelRequest,
StorageConversation,
)
from dbgpt.core.awel import (
DAG,
BaseOperator,
CommonLLMHttpRequestBody,
DAGContext,
DefaultInputContext,
InputOperator,
JoinOperator,
MapOperator,
SimpleCallDataInputSource,
TaskOutput,
)
from dbgpt.core.awel.flow import (
TAGS_ORDER_HIGH,
IOField,
OperatorCategory,
OptionValue,
Parameter,
ViewMetadata,
ui,
)
from dbgpt.core.interface.operators.message_operator import (
BaseConversationOperator,
BufferedConversationMapperOperator,
TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operators.prompt_operator import HistoryPromptBuilderOperator
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.serve.conversation.serve import Serve as ConversationServe
from dbgpt.util.i18n_utils import _
from dbgpt.util.tracer import root_tracer
class HOContextBody(BaseModel):
"""Higher-order context body."""
context_key: str = Field(
"context",
description=_("The context key can be used as the key for formatting prompt."),
)
context: Union[str, List[str]] = Field(
...,
description=_("The context."),
)
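
# A minimal usage sketch: an upstream operator (e.g. a retriever) can emit an
# HOContextBody so that its result is formatted into the prompt under
# `context_key`. The key name and schema text below are hypothetical example
# values, not real data:
#
#     HOContextBody(
#         context_key="db_schema",
#         context="Table `user`: id INT, name VARCHAR(100)",
#     )
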
class BaseHOLLMOperator(
BaseConversationOperator,
JoinOperator[ModelRequest],
LLMOperator,
StreamingLLMOperator,
):
"""Higher-order model request builder operator."""
def __init__(
self,
prompt_template: ChatPromptTemplate,
        model: Optional[str] = None,
llm_client: Optional[LLMClient] = None,
history_merge_mode: Literal["none", "window", "token"] = "window",
user_message_key: str = "user_input",
history_key: Optional[str] = None,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
max_token_limit: int = 2048,
**kwargs,
):
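        """Create a higher-order LLM operator.

        Joins the incoming request body, any extra contexts and the stored
        chat history into a single ModelRequest.
        """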
JoinOperator.__init__(self, combine_function=self._join_func, **kwargs)
LLMOperator.__init__(self, llm_client=llm_client, **kwargs)
StreamingLLMOperator.__init__(self, llm_client=llm_client, **kwargs)
        # History merge mode selected by the user: "none", "window" or "token"
self._history_merge_mode = history_merge_mode
self._user_message_key = user_message_key
self._has_history = history_merge_mode != "none"
self._prompt_template = prompt_template
self._model = model
self._history_key = history_key
self._str_history = False
self._keep_start_rounds = keep_start_rounds if self._has_history else 0
self._keep_end_rounds = keep_end_rounds if self._has_history else 0
self._max_token_limit = max_token_limit
self._sub_compose_dag = self._build_conversation_composer_dag()
async def _do_run(self, dag_ctx: DAGContext) -> TaskOutput[ModelOutput]:
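        """Run the join step, then dispatch to the streaming or non-streaming LLM path."""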
conv_serve = ConversationServe.get_instance(self.system_app)
self._storage = conv_serve.conv_storage
self._message_storage = conv_serve.message_storage
_: TaskOutput[ModelRequest] = await JoinOperator._do_run(self, dag_ctx)
dag_ctx.current_task_context.set_task_input(
DefaultInputContext([dag_ctx.current_task_context])
)
if dag_ctx.streaming_call:
task_output = await StreamingLLMOperator._do_run(self, dag_ctx)
else:
task_output = await LLMOperator._do_run(self, dag_ctx)
return task_output
async def after_dag_end(self, event_loop_task_id: int):
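        """Persist the model output (and view) to conversation storage when the DAG ends."""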
model_output: Optional[
ModelOutput
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT
)
model_output_view: Optional[
str
] = await self.current_dag_context.get_from_share_data(
LLMOperator.SHARE_DATA_KEY_MODEL_OUTPUT_VIEW
)
storage_conv = await self.get_storage_conversation()
end_current_round: bool = False
if model_output and storage_conv:
# Save model output message to storage
storage_conv.add_ai_message(model_output.text)
end_current_round = True
if model_output_view and storage_conv:
# Save model output view to storage
storage_conv.add_view_message(model_output_view)
end_current_round = True
if end_current_round:
# End current conversation round and flush to storage
storage_conv.end_current_round()
async def _join_func(self, req: CommonLLMHttpRequestBody, *args):
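        """Combine the request body and upstream HOContextBody inputs into a ModelRequest."""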
dynamic_inputs = []
for arg in args:
if isinstance(arg, HOContextBody):
dynamic_inputs.append(arg)
        # Load and store chat history (InMemoryStorage is used by default).
storage_conv, history_messages = await self.blocking_func_to_async(
self._build_storage, req
)
        # Save the storage conversation to shared data for the child operators
await self.current_dag_context.save_to_share_data(
self.SHARE_DATA_KEY_STORAGE_CONVERSATION, storage_conv
)
user_input = (
req.messages[-1] if isinstance(req.messages, list) else req.messages
)
prompt_dict = {
self._user_message_key: user_input,
}
for dynamic_input in dynamic_inputs:
if dynamic_input.context_key in prompt_dict:
raise ValueError(
f"Duplicate context key '{dynamic_input.context_key}' in upstream "
f"operators."
)
prompt_dict[dynamic_input.context_key] = dynamic_input.context
call_data = {
"messages": history_messages,
"prompt_dict": prompt_dict,
}
end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
        # Sub-DAG call: reuse the parent DAG's context
messages = await end_node.call(call_data, dag_ctx=self.current_dag_context)
model_request = ModelRequest.build_request(
model=req.model,
messages=messages,
context=req.context,
temperature=req.temperature,
max_new_tokens=req.max_new_tokens,
span_id=root_tracer.get_current_span_id(),
echo=False,
)
if storage_conv:
# Start new round
storage_conv.start_new_round()
storage_conv.add_user_message(user_input)
return model_request
def _build_storage(
self, req: CommonLLMHttpRequestBody
) -> Tuple[StorageConversation, List[BaseMessage]]:
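        """Create the storage conversation and load its history messages (blocking I/O)."""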
        # Create a new storage conversation; this loads the conversation from
        # storage (blocking I/O), which is why callers wrap it with
        # blocking_func_to_async
storage_conv: StorageConversation = StorageConversation(
conv_uid=req.conv_uid,
chat_mode=req.chat_mode,
user_name=req.user_name,
sys_code=req.sys_code,
conv_storage=self.storage,
message_storage=self.message_storage,
param_type="",
param_value=req.chat_param,
)
# Get history messages from storage
history_messages: List[BaseMessage] = storage_conv.get_history_message(
include_system_message=False
)
return storage_conv, history_messages
def _build_conversation_composer_dag(self) -> DAG:
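        """Build the sub-DAG that merges chat history into the prompt."""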
with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
input_task = InputOperator(input_source=SimpleCallDataInputSource())
# History transform task
if self._history_merge_mode == "token":
history_transform_task = TokenBufferedConversationMapperOperator(
model=self._model,
llm_client=self.llm_client,
max_token_limit=self._max_token_limit,
)
else:
history_transform_task = BufferedConversationMapperOperator(
keep_start_rounds=self._keep_start_rounds,
keep_end_rounds=self._keep_end_rounds,
)
if self._history_key:
history_key = self._history_key
else:
placeholders = self._prompt_template.get_placeholders()
if not placeholders or len(placeholders) != 1:
raise ValueError(
"The prompt template must have exactly one placeholder if "
"history_key is not provided."
)
history_key = placeholders[0]
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template,
history_key=history_key,
check_storage=False,
save_to_storage=False,
str_history=self._str_history,
)
# Build composer dag
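            # Wiring: two branches fan out from the input task and join again
            # at the prompt builder:
            #
            #     input -> x["messages"]    -> history transform -+
            #                                                     +-> prompt build
            #     input -> x["prompt_dict"] ----------------------+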
(
input_task
>> MapOperator(lambda x: x["messages"])
>> history_transform_task
>> history_prompt_build_task
)
(
input_task
>> MapOperator(lambda x: x["prompt_dict"])
>> history_prompt_build_task
)
return composer_dag
_PARAMETER_PROMPT_TEMPLATE = Parameter.build_from(
_("Prompt Template"),
"prompt_template",
ChatPromptTemplate,
description=_("The prompt template for the conversation."),
)
_PARAMETER_MODEL = Parameter.build_from(
_("Model Name"),
"model",
str,
optional=True,
default=None,
description=_("The model name."),
)
_PARAMETER_LLM_CLIENT = Parameter.build_from(
_("LLM Client"),
"llm_client",
LLMClient,
optional=True,
default=None,
    description=_(
        "The LLM client used to connect to the model; if not provided, the "
        "default client deployed by DB-GPT is used."
    ),
)
_PARAMETER_HISTORY_MERGE_MODE = Parameter.build_from(
_("History Message Merge Mode"),
"history_merge_mode",
str,
optional=True,
default="none",
options=[
OptionValue(label="No History", name="none", value="none"),
OptionValue(label="Message Window", name="window", value="window"),
OptionValue(label="Token Length", name="token", value="token"),
],
    description=_(
        "The history merge mode; one of 'none', 'window' or 'token'. "
        "'none': no history merge, 'window': merge by conversation window, "
        "'token': merge by token length."
    ),
ui=ui.UISelect(),
)
_PARAMETER_USER_MESSAGE_KEY = Parameter.build_from(
_("User Message Key"),
"user_message_key",
str,
optional=True,
default="user_input",
    description=_(
        "The key of the user message in your prompt; defaults to 'user_input'."
    ),
)
_PARAMETER_HISTORY_KEY = Parameter.build_from(
_("History Key"),
"history_key",
str,
optional=True,
default=None,
    description=_(
        "The key used to pass chat history messages to the prompt template; "
        "if not provided, the key is parsed from the prompt template."
    ),
)
_PARAMETER_KEEP_START_ROUNDS = Parameter.build_from(
_("Keep Start Rounds"),
"keep_start_rounds",
int,
optional=True,
default=None,
description=_("The start rounds to keep in the chat history."),
)
_PARAMETER_KEEP_END_ROUNDS = Parameter.build_from(
_("Keep End Rounds"),
"keep_end_rounds",
int,
optional=True,
default=None,
description=_("The end rounds to keep in the chat history."),
)
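
# A hedged sketch of how the two "keep rounds" parameters are expected to
# combine in "window" mode (semantics follow BufferedConversationMapperOperator;
# the concrete numbers are hypothetical):
#
#     keep_start_rounds=1, keep_end_rounds=2
#     -> keep the first round plus the last two rounds of the conversation.
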
_PARAMETER_MAX_TOKEN_LIMIT = Parameter.build_from(
_("Max Token Limit"),
"max_token_limit",
int,
optional=True,
default=2048,
description=_("The max token limit to keep in the chat history."),
)
_INPUTS_COMMON_LLM_REQUEST_BODY = IOField.build_from(
_("Common LLM Request Body"),
"common_llm_request_body",
CommonLLMHttpRequestBody,
_("The common LLM request body."),
)
_INPUTS_EXTRA_CONTEXT = IOField.build_from(
_("Extra Context"),
"extra_context",
HOContextBody,
    _(
        "Extra context for building the prompt (knowledge context, database "
        "schema, etc.); multiple contexts can be added."
    ),
dynamic=True,
)
_OUTPUTS_MODEL_OUTPUT = IOField.build_from(
_("Model Output"),
"model_output",
ModelOutput,
description=_("The model output."),
)
_OUTPUTS_STREAMING_MODEL_OUTPUT = IOField.build_from(
_("Streaming Model Output"),
"streaming_model_output",
ModelOutput,
is_list=True,
description=_("The streaming model output."),
)
class HOLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("LLM Operator"),
name="higher_order_llm_operator",
category=OperatorCategory.LLM,
    description=_(
        "High-level LLM operator supporting multi-round conversation "
        "(conversation window, token length, or no history)."
    ),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
class HOStreamingLLMOperator(BaseHOLLMOperator):
metadata = ViewMetadata(
label=_("Streaming LLM Operator"),
name="higher_order_streaming_llm_operator",
category=OperatorCategory.LLM,
    description=_(
        "High-level streaming LLM operator supporting multi-round conversation "
        "(conversation window, token length, or no history)."
    ),
parameters=[
_PARAMETER_PROMPT_TEMPLATE.new(),
_PARAMETER_MODEL.new(),
_PARAMETER_LLM_CLIENT.new(),
_PARAMETER_HISTORY_MERGE_MODE.new(),
_PARAMETER_USER_MESSAGE_KEY.new(),
_PARAMETER_HISTORY_KEY.new(),
_PARAMETER_KEEP_START_ROUNDS.new(),
_PARAMETER_KEEP_END_ROUNDS.new(),
_PARAMETER_MAX_TOKEN_LIMIT.new(),
],
inputs=[
_INPUTS_COMMON_LLM_REQUEST_BODY.new(),
_INPUTS_EXTRA_CONTEXT.new(),
],
outputs=[
_OUTPUTS_STREAMING_MODEL_OUTPUT.new(),
],
tags={"order": TAGS_ORDER_HIGH},
)
def __init__(self, **kwargs):
super().__init__(**kwargs)
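
# A minimal wiring sketch (hypothetical; the trigger operator and the exact
# ChatPromptTemplate construction are assumptions, not part of this module):
#
#     with DAG("my_chat_dag") as dag:
#         trigger = SomeHttpTrigger(...)           # hypothetical trigger emitting
#                                                  # a CommonLLMHttpRequestBody
#         llm_task = HOLLMOperator(
#             prompt_template=my_prompt_template,  # a ChatPromptTemplate
#             history_merge_mode="window",
#             keep_end_rounds=5,
#         )
#         trigger >> llm_task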