# DB-GPT/dbgpt/app/scene/operators/app_operator.py

import dataclasses
from typing import Any, Dict, List, Optional

from dbgpt import SystemApp
from dbgpt.component import ComponentType
from dbgpt.core import (
    BaseMessage,
    ChatPromptTemplate,
    LLMClient,
    ModelRequest,
    ModelRequestContext,
)
from dbgpt.core.awel import (
    DAG,
    BaseOperator,
    BranchJoinOperator,
    InputOperator,
    MapOperator,
    SimpleCallDataInputSource,
)
from dbgpt.core.operators import (
    BufferedConversationMapperOperator,
    HistoryPromptBuilderOperator,
)
from dbgpt.model.operators import LLMOperator, StreamingLLMOperator
from dbgpt.storage.cache.operators import (
    CachedModelOperator,
    CachedModelStreamOperator,
    CacheManager,
    ModelCacheBranchOperator,
    ModelSaveCacheOperator,
    ModelStreamSaveCacheOperator,
)


@dataclasses.dataclass
class ChatComposerInput:
    """The composer input."""

    messages: List[BaseMessage]
    prompt_dict: Dict[str, Any]
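
# Illustrative sketch (not part of the original module): a ChatComposerInput
# bundles the current conversation messages with the values used to render the
# prompt template. The message class and the "user_input" key below are
# assumptions that depend on the concrete prompt in use.
#
#     composer_input = ChatComposerInput(
#         messages=[HumanMessage(content="Hello")],
#         prompt_dict={"user_input": "Hello"},
#     )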


class AppChatComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
    """App chat composer operator.

    TODO: Support more history merge modes.
    """

    def __init__(
        self,
        model: str,
        temperature: float,
        max_new_tokens: int,
        prompt: ChatPromptTemplate,
        message_version: str = "v2",
        echo: bool = False,
        streaming: bool = True,
        history_key: str = "chat_history",
        history_merge_mode: str = "window",
        keep_start_rounds: Optional[int] = None,
        keep_end_rounds: Optional[int] = None,
        str_history: bool = False,
        request_context: Optional[ModelRequestContext] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        if not request_context:
            request_context = ModelRequestContext(stream=streaming)
        self._prompt_template = prompt
        self._history_key = history_key
        self._history_merge_mode = history_merge_mode
        self._keep_start_rounds = keep_start_rounds
        self._keep_end_rounds = keep_end_rounds
        self._str_history = str_history
        self._model_name = model
        self._temperature = temperature
        self._max_new_tokens = max_new_tokens
        self._message_version = message_version
        self._echo = echo
        self._streaming = streaming
        self._request_context = request_context
        self._sub_compose_dag = self._build_composer_dag()

    async def map(self, input_value: ChatComposerInput) -> ModelRequest:
        end_node: BaseOperator = self._sub_compose_dag.leaf_nodes[0]
        # Sub DAG: run it with the same DAG context as the parent DAG
        messages = await end_node.call(
            call_data=input_value, dag_ctx=self.current_dag_context
        )
        span_id = self._request_context.span_id
        model_request = ModelRequest.build_request(
            model=self._model_name,
            messages=messages,
            context=self._request_context,
            temperature=self._temperature,
            max_new_tokens=self._max_new_tokens,
            span_id=span_id,
            echo=self._echo,
        )
        return model_request

    def _build_composer_dag(self) -> DAG:
        with DAG("dbgpt_awel_app_chat_history_prompt_composer") as composer_dag:
            input_task = InputOperator(input_source=SimpleCallDataInputSource())
            # History transform task
            history_transform_task = BufferedConversationMapperOperator(
                keep_start_rounds=self._keep_start_rounds,
                keep_end_rounds=self._keep_end_rounds,
            )
            history_prompt_build_task = HistoryPromptBuilderOperator(
                prompt=self._prompt_template,
                history_key=self._history_key,
                check_storage=False,
                str_history=self._str_history,
            )
            # Build composer dag
            (
                input_task
                >> MapOperator(lambda x: x.messages)
                >> history_transform_task
                >> history_prompt_build_task
            )
            (
                input_task
                >> MapOperator(lambda x: x.prompt_dict)
                >> history_prompt_build_task
            )
        return composer_dag
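
# Illustrative usage sketch (not part of the original module). It assumes an
# existing ChatPromptTemplate named `prompt`, a list of history `messages`, an
# async context, and a placeholder model name; it mirrors the InputOperator /
# call(call_data=...) pattern used elsewhere in this file.
#
#     with DAG("example_app_chat_composer"):
#         input_task = InputOperator(input_source=SimpleCallDataInputSource())
#         composer_task = AppChatComposerOperator(
#             model="example-model",
#             temperature=0.6,
#             max_new_tokens=1024,
#             prompt=prompt,
#             keep_end_rounds=5,
#         )
#         input_task >> composer_task
#     request: ModelRequest = await composer_task.call(
#         call_data=ChatComposerInput(messages=messages, prompt_dict={})
#     )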


def build_cached_chat_operator(
    llm_client: LLMClient,
    is_streaming: bool,
    system_app: SystemApp,
    cache_manager: Optional[CacheManager] = None,
):
    r"""Builds and returns a model processing workflow (DAG) operator.

    This function constructs a Directed Acyclic Graph (DAG) for processing data
    using a model. It includes caching and branching logic to either fetch
    results from a cache or process data using the model. It supports both
    streaming and non-streaming modes.

    .. code-block:: python

        input_task >> cache_check_branch_task
        cache_check_branch_task >> llm_task >> save_cache_task >> join_task
        cache_check_branch_task >> cache_task >> join_task

    equivalent to::

                              -> llm_task -> save_cache_task ->
                             /                                  \
        input_task -> cache_check_branch_task                    ---> join_task
                             \                                  /
                              -> cache_task ------------------ ->

    Args:
        llm_client (LLMClient): The LLM client for processing data using the model.
        is_streaming (bool): Whether the model is a streaming model.
        system_app (SystemApp): The system app.
        cache_manager (CacheManager, optional): The cache manager for managing
            cache operations. Defaults to None.

    Returns:
        BaseOperator: The final operator in the constructed DAG, typically a
            join node.
    """
    # Define task names for the model and cache nodes
    model_task_name = "llm_model_node"
    cache_task_name = "llm_model_cache_node"
    if not cache_manager:
        cache_manager: CacheManager = system_app.get_component(
            ComponentType.MODEL_CACHE_MANAGER, CacheManager
        )
    with DAG("dbgpt_awel_app_model_infer_with_cached") as dag:
        # Create an input task
        input_task = InputOperator(SimpleCallDataInputSource())
        # Create the model task, the cache read task and the cache save task,
        # in streaming or non-streaming mode
        if is_streaming:
            llm_task = StreamingLLMOperator(llm_client, task_name=model_task_name)
            cache_task = CachedModelStreamOperator(
                cache_manager, task_name=cache_task_name
            )
            save_cache_task = ModelStreamSaveCacheOperator(cache_manager)
        else:
            llm_task = LLMOperator(llm_client, task_name=model_task_name)
            cache_task = CachedModelOperator(cache_manager, task_name=cache_task_name)
            save_cache_task = ModelSaveCacheOperator(cache_manager)
        # Create a branch node to decide between fetching from cache or
        # processing with the model
        cache_check_branch_task = ModelCacheBranchOperator(
            cache_manager,
            model_task_name=model_task_name,
            cache_task_name=cache_task_name,
        )
        # Create a join node to merge outputs from the model and cache nodes,
        # keeping only the first non-empty output
        join_task = BranchJoinOperator()

        # Define the workflow structure using the >> operator
        input_task >> cache_check_branch_task
        cache_check_branch_task >> llm_task >> save_cache_task >> join_task
        cache_check_branch_task >> cache_task >> join_task
    return join_task
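
# Illustrative usage sketch (not part of the original module). It assumes a
# running SystemApp instance, an LLMClient, and a ModelRequest (for example,
# one produced by AppChatComposerOperator), all inside an async context.
#
#     cached_task = build_cached_chat_operator(llm_client, False, system_app)
#     model_output = await cached_task.call(call_data=model_request)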