mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-11 05:49:22 +00:00
feat(awel): New MessageConverter and more AWEL operators (#1039)
@@ -26,7 +26,7 @@
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
         -H "Content-Type: application/json" -d '{
             "command": "dbgpt_awel_data_analyst_code_fix",
-            "model": "gpt-3.5-turbo",
+            "model": "'"$MODEL"'",
             "stream": false,
             "context": {
                 "conv_uid": "uuid_conv_copilot_1234",
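The same call can be made from Python. A minimal sketch, assuming a local server at http://127.0.0.1:5555, the third-party requests package, and an illustrative messages payload; none of these come from the commit itself:

    import requests

    # Hypothetical client for the copilot trigger shown above; the payload
    # mirrors the curl example, and "command" selects a registered prompt.
    DBGPT_SERVER = "http://127.0.0.1:5555"  # assumed local deployment
    resp = requests.post(
        f"{DBGPT_SERVER}/api/v1/awel/trigger/examples/data_analyst/copilot",
        json={
            "command": "dbgpt_awel_data_analyst_code_fix",
            "model": "gpt-3.5-turbo",
            "stream": False,
            "context": {"conv_uid": "uuid_conv_copilot_1234"},
            "messages": "SELECT * FROM users WHERE id = ",  # illustrative input
        },
        timeout=60,
    )
    print(resp.json())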
@@ -37,43 +37,55 @@

 """
 import logging
+import os
 from functools import cache
 from typing import Any, Dict, List, Optional

 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core import (
-    InMemoryStorage,
-    LLMClient,
-    MessageStorageItem,
+    ChatPromptTemplate,
+    HumanPromptTemplate,
+    MessagesPlaceholder,
     ModelMessage,
     ModelMessageRoleType,
+    ModelRequest,
+    ModelRequestContext,
     PromptManager,
     PromptTemplate,
-    StorageConversation,
-    StorageInterface,
+    SystemPromptTemplate,
 )
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
 from dbgpt.core.operator import (
     BufferedConversationMapperOperator,
+    HistoryDynamicPromptBuilderOperator,
     LLMBranchOperator,
+    RequestBuilderOperator,
+)
+from dbgpt.model.operator import (
     LLMOperator,
-    PostConversationOperator,
-    PostStreamingConversationOperator,
-    PreConversationOperator,
-    RequestBuildOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
 from dbgpt.util.utils import colored
+from dbgpt.serve.conversation.operator import ServePreChatHistoryLoadOperator

 logger = logging.getLogger(__name__)

 PROMPT_LANG_ZH = "zh"
 PROMPT_LANG_EN = "en"

+CODE_DEFAULT = "dbgpt_awel_data_analyst_code_default"
 CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
 CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
 CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
 CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
 CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"

+CODE_DEFAULT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师。
+你可以根据最佳实践来优化代码, 也可以对代码进行修复, 解释, 添加注释, 以及将代码翻译成其他语言。"""
+CODE_DEFAULT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst.
+You can optimize the code according to best practices, or fix, explain, add comments to the code,
+and you can also translate the code into other languages.
+"""
+
 CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
 这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
 CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
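Each command template above is parameterized: the fix, perf, explain, and comment templates take a {language} variable, while translate takes {source_language} and {target_language}. A standalone sketch of that formatting contract using plain str.format semantics, with an abbreviated stand-in rather than the full template text from the source:

    # Abbreviated stand-in, not the full CODE_FIX_TEMPLATE_EN from the diff.
    template = (
        "As an experienced data warehouse developer and data analyst, "
        "here is a piece of {language} code to review and fix."
    )
    # PromptTemplate declares input_variables=["language"]; formatting fills it.
    system_message = template.format(language="hive")
    print(system_message)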
@@ -126,7 +138,9 @@ class ReqContext(BaseModel):

 class TriggerReqBody(BaseModel):
     messages: str = Field(..., description="User input messages")
-    command: Optional[str] = Field(default="fix", description="Command name")
+    command: Optional[str] = Field(
+        default=None, description="Command name, None if common chat"
+    )
     model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
     stream: Optional[bool] = Field(default=False, description="Whether return stream")
     language: Optional[str] = Field(default="hive", description="Language")
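The behavioral change in this hunk is the command default: None now means plain chat with no command-specific system prompt. A runnable re-declaration with plain pydantic (an assumption; the project uses its own dbgpt._private.pydantic wrapper) showing the new default:

    from typing import Optional

    from pydantic import BaseModel, Field

    class TriggerReqBody(BaseModel):
        messages: str = Field(..., description="User input messages")
        command: Optional[str] = Field(
            default=None, description="Command name, None if common chat"
        )
        model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
        stream: Optional[bool] = Field(default=False, description="Whether return stream")
        language: Optional[str] = Field(default="hive", description="Language")

    req = TriggerReqBody(messages="hello")
    assert req.command is None  # no command: the DAG takes the plain-chat path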
@@ -140,109 +154,89 @@ class TriggerReqBody(BaseModel):

 @cache
 def load_or_save_prompt_template(pm: PromptManager):
-    ext_params = {
+    zh_ext_params = {
         "chat_scene": "chat_with_code",
         "sub_chat_scene": "data_analyst",
         "prompt_type": "common",
+        "prompt_language": PROMPT_LANG_ZH,
+    }
+    en_ext_params = {
+        "chat_scene": "chat_with_code",
+        "sub_chat_scene": "data_analyst",
+        "prompt_type": "common",
+        "prompt_language": PROMPT_LANG_EN,
     }

     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_FIX_TEMPLATE_ZH,
-        ),
+        PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_ZH),
+        prompt_name=CODE_DEFAULT,
+        **zh_ext_params,
+    )
+    pm.query_or_save(
+        PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_EN),
+        prompt_name=CODE_DEFAULT,
+        **en_ext_params,
+    )
+    pm.query_or_save(
+        PromptTemplate.from_template(CODE_FIX_TEMPLATE_ZH),
         prompt_name=CODE_FIX,
-        prompt_language="zh",
-        **ext_params,
+        **zh_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_FIX_TEMPLATE_EN,
-        ),
+        PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN),
         prompt_name=CODE_FIX,
-        prompt_language="en",
-        **ext_params,
+        **en_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_PERF_TEMPLATE_ZH,
-        ),
+        PromptTemplate.from_template(CODE_PERF_TEMPLATE_ZH),
         prompt_name=CODE_PERF,
-        prompt_language="zh",
-        **ext_params,
+        **zh_ext_params,
    )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_PERF_TEMPLATE_EN,
-        ),
+        PromptTemplate.from_template(CODE_PERF_TEMPLATE_EN),
         prompt_name=CODE_PERF,
-        prompt_language="en",
-        **ext_params,
+        **en_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_EXPLAIN_TEMPLATE_ZH,
-        ),
+        PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_ZH),
         prompt_name=CODE_EXPLAIN,
-        prompt_language="zh",
-        **ext_params,
+        **zh_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_EXPLAIN_TEMPLATE_EN,
-        ),
+        PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_EN),
         prompt_name=CODE_EXPLAIN,
-        prompt_language="en",
-        **ext_params,
+        **en_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_COMMENT_TEMPLATE_ZH,
-        ),
+        PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_ZH),
         prompt_name=CODE_COMMENT,
-        prompt_language="zh",
-        **ext_params,
+        **zh_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["language"],
-            template=CODE_COMMENT_TEMPLATE_EN,
-        ),
+        PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_EN),
         prompt_name=CODE_COMMENT,
-        prompt_language="en",
-        **ext_params,
+        **en_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["source_language", "target_language"],
-            template=CODE_TRANSLATE_TEMPLATE_ZH,
-        ),
+        PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_ZH),
         prompt_name=CODE_TRANSLATE,
-        prompt_language="zh",
-        **ext_params,
+        **zh_ext_params,
     )
     pm.query_or_save(
-        PromptTemplate(
-            input_variables=["source_language", "target_language"],
-            template=CODE_TRANSLATE_TEMPLATE_EN,
-        ),
+        PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_EN),
         prompt_name=CODE_TRANSLATE,
-        prompt_language="en",
-        **ext_params,
+        **en_ext_params,
     )


-class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
+class PromptTemplateBuilderOperator(MapOperator[TriggerReqBody, ChatPromptTemplate]):
+    """Build prompt template for chat with code."""

     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self._default_prompt_manager = PromptManager()

-    async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
+    async def map(self, input_value: TriggerReqBody) -> ChatPromptTemplate:
         from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
         from dbgpt.serve.prompt.serve import Serve as PromptServe
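A note on the @cache decorator kept on load_or_save_prompt_template: it makes the registration idempotent per PromptManager argument (functools.cache keys on the argument, so pm must be hashable), meaning repeated DAG runs do not re-issue the query_or_save calls. A minimal standalone sketch of that call-once behavior, stdlib only, with a counter standing in for the registrations:

    from functools import cache

    calls = []

    @cache
    def load_once(pm_id: int) -> None:
        # Stands in for the pm.query_or_save(...) calls above.
        calls.append(pm_id)

    load_once(1)
    load_once(1)  # cached: the body does not run again for the same argument
    assert calls == [1]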
@@ -256,7 +250,24 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
         load_or_save_prompt_template(pm)

         user_language = self.system_app.config.get_current_lang(default="en")
+        if not input_value.command:
+            # No command, just chat, not include system prompt.
+            default_prompt_list = pm.prefer_query(
+                CODE_DEFAULT, prefer_prompt_language=user_language
+            )
+            default_prompt_template = (
+                default_prompt_list[0].to_prompt_template().template
+            )
+            prompt = ChatPromptTemplate(
+                messages=[
+                    SystemPromptTemplate.from_template(default_prompt_template),
+                    MessagesPlaceholder(variable_name="chat_history"),
+                    HumanPromptTemplate.from_template("{user_input}"),
+                ]
+            )
+            return prompt
+
         # Query prompt template from prompt manager by command name
         prompt_list = pm.prefer_query(
             input_value.command, prefer_prompt_language=user_language
         )
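The ChatPromptTemplate assembled here is the recurring pattern of this commit: a system prompt, a chat_history placeholder filled in later from storage, and a human template carrying the user input. Mirrored as a minimal sketch using the same dbgpt.core imports this diff adds (requires a DB-GPT installation to run):

    from dbgpt.core import (
        ChatPromptTemplate,
        HumanPromptTemplate,
        MessagesPlaceholder,
        SystemPromptTemplate,
    )

    # Same three-part structure as the code above; the "chat_history"
    # placeholder is populated by the history operators at runtime.
    prompt = ChatPromptTemplate(
        messages=[
            SystemPromptTemplate.from_template("You are a helpful chatbot."),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanPromptTemplate.from_template("{user_input}"),
        ]
    )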
@@ -264,109 +275,38 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
             error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
             logger.error(error_msg)
             raise ValueError(error_msg)
-        prompt = prompt_list[0].to_prompt_template()
-        if input_value.command == CODE_TRANSLATE:
-            format_params = {
-                "source_language": input_value.language,
-                "target_language": input_value.target_language,
-            }
-        else:
-            format_params = {"language": input_value.language}
+        prompt_template = prompt_list[0].to_prompt_template()

-        system_message = prompt.format(**format_params)
-        messages = [
-            ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message),
-            ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages),
-        ]
-        context = input_value.context.dict() if input_value.context else {}
-        return {
-            "messages": messages,
-            "stream": input_value.stream,
-            "model": input_value.model,
-            "context": context,
-        }
-
-
-class MyConversationOperator(PreConversationOperator):
-    def __init__(
-        self,
-        storage: Optional[StorageInterface[StorageConversation, Any]] = None,
-        message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
-        **kwargs,
-    ):
-        super().__init__(storage, message_storage, **kwargs)
-
-    def _get_conversion_serve(self):
-        from dbgpt.serve.conversation.serve import (
-            SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
-        )
-        from dbgpt.serve.conversation.serve import Serve as ConversationServe
-
-        conversation_serve: ConversationServe = self.system_app.get_component(
-            CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
-        )
-        return conversation_serve
-
-    @property
-    def storage(self):
-        if self._storage:
-            return self._storage
-        conversation_serve = self._get_conversion_serve()
-        if conversation_serve:
-            return conversation_serve.conv_storage
-        else:
-            logger.info("Conversation storage not found, use InMemoryStorage default")
-            self._storage = InMemoryStorage()
-            return self._storage
-
-    @property
-    def message_storage(self):
-        if self._message_storage:
-            return self._message_storage
-        conversation_serve = self._get_conversion_serve()
-        if conversation_serve:
-            return conversation_serve.message_storage
-        else:
-            logger.info("Message storage not found, use InMemoryStorage default")
-            self._message_storage = InMemoryStorage()
-            return self._message_storage
-
-
-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
-
-
-def history_message_mapper(
-    messages_by_round: List[List[ModelMessage]],
-) -> List[ModelMessage]:
-    """Mapper for history conversation.
-
-    If there are multi system messages, just keep the first system message.
-    """
-    has_system_message = False
-    mapper_messages = []
-    for messages in messages_by_round:
-        for message in messages:
-            if message.role == ModelMessageRoleType.SYSTEM:
-                if has_system_message:
-                    continue
-                else:
-                    mapper_messages.append(message)
-                    has_system_message = True
-            else:
-                mapper_messages.append(message)
-    print("history_message_mapper start:" + "=" * 70)
-    print(colored(ModelMessage.get_printable_message(mapper_messages), "green"))
-    print("history_message_mapper end:" + "=" * 72)
-    return mapper_messages
+        return ChatPromptTemplate(
+            messages=[
+                SystemPromptTemplate.from_template(prompt_template.template),
+                MessagesPlaceholder(variable_name="chat_history"),
+                HumanPromptTemplate.from_template("{user_input}"),
+            ]
+        )
+
+
+def parse_prompt_args(req: TriggerReqBody) -> Dict[str, Any]:
+    prompt_args = {"user_input": req.messages}
+    if not req.command:
+        return prompt_args
+    if req.command == CODE_TRANSLATE:
+        prompt_args["source_language"] = req.language
+        prompt_args["target_language"] = req.target_language
+    else:
+        prompt_args["language"] = req.language
+    return prompt_args
+
+
+async def build_model_request(
+    messages: List[ModelMessage], req_body: TriggerReqBody
+) -> ModelRequest:
+    return ModelRequest.build_request(
+        model=req_body.model,
+        messages=messages,
+        context=req_body.context,
+        stream=req_body.stream,
+    )


 with DAG("dbgpt_awel_data_analyst_assistant") as dag:
@@ -377,57 +317,59 @@ with DAG("dbgpt_awel_data_analyst_assistant") as dag:
         streaming_predict_func=lambda x: x.stream,
     )

-    copilot_task = CopilotOperator()
-    request_handle_task = RequestBuildOperator()
+    prompt_template_load_task = PromptTemplateBuilderOperator()
+    request_handle_task = RequestBuilderOperator()

-    # Pre-process conversation
-    pre_conversation_task = MyConversationOperator()
-    # Keep last k round conversation.
-    history_conversation_task = BufferedConversationMapperOperator(
-        last_k_round=5, message_mapper=history_message_mapper
+    # Load and store chat history
+    chat_history_load_task = ServePreChatHistoryLoadOperator()
+    last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5))
+    # History transform task, here we keep last k round messages
+    history_transform_task = BufferedConversationMapperOperator(
+        last_k_round=last_k_round
     )
+    history_prompt_build_task = HistoryDynamicPromptBuilderOperator(
+        history_key="chat_history"
+    )

-    # Save conversation to storage.
-    post_conversation_task = PostConversationOperator()
-    # Save streaming conversation to storage.
-    post_streaming_conversation_task = PostStreamingConversationOperator()
+    model_request_build_task = JoinOperator(build_model_request)

-    # Use LLMOperator to generate response.
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    # Use BaseLLMOperator to generate response.
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )
+    trigger >> prompt_template_load_task >> history_prompt_build_task

     (
         trigger
-        >> copilot_task
-        >> request_handle_task
-        >> pre_conversation_task
-        >> history_conversation_task
-        >> branch_task
+        >> MapOperator(
+            lambda req: ModelRequestContext(
+                conv_uid=req.context.conv_uid,
+                stream=req.stream,
+                chat_mode=req.context.chat_mode,
+            )
+        )
+        >> chat_history_load_task
+        >> history_transform_task
+        >> history_prompt_build_task
     )
+
+    trigger >> MapOperator(parse_prompt_args) >> history_prompt_build_task
+
+    history_prompt_build_task >> model_request_build_task
+    trigger >> model_request_build_task
+
+    model_request_build_task >> branch_task
     # The branch of no streaming response.
-    (
-        branch_task
-        >> llm_task
-        >> post_conversation_task
-        >> model_parse_task
-        >> result_join_task
-    )
+    (branch_task >> llm_task >> model_parse_task >> result_join_task)
     # The branch of streaming response.
-    (
-        branch_task
-        >> streaming_llm_task
-        >> post_streaming_conversation_task
-        >> openai_format_stream_task
-        >> result_join_task
-    )
+    (branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task)

 if __name__ == "__main__":
     if dag.leaf_nodes[0].dev_mode:
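The rewired DAG fans several upstream edges into history_prompt_build_task and model_request_build_task, and result_join_task finally merges the streaming and non-streaming branches. A sketch of the join contract implied by the combine_function lambda above; the assumption, drawn from the lambda rather than from AWEL documentation, is that the branch that is skipped contributes a falsy value:

    # Only one LLM branch actually runs; `or` picks whichever produced output.
    def combine(not_stream_out, stream_out):
        return not_stream_out or stream_out

    assert combine({"text": "full response"}, None) == {"text": "full response"}
    assert combine(None, "streaming-chunks") == "streaming-chunks"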
@@ -12,7 +12,7 @@
         # First round
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "gpt-3.5-turbo",
+            "model": "'"$MODEL"'",
             "context": {
                 "conv_uid": "uuid_conv_1234"
             },
@@ -22,7 +22,7 @@
         # Second round
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "gpt-3.5-turbo",
+            "model": "'"$MODEL"'",
             "context": {
                 "conv_uid": "uuid_conv_1234"
             },
@@ -34,7 +34,7 @@

         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "gpt-3.5-turbo",
+            "model": "'"$MODEL"'",
             "context": {
                 "conv_uid": "uuid_conv_stream_1234"
             },
@@ -45,7 +45,7 @@
         # Second round
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "gpt-3.5-turbo",
+            "model": "'"$MODEL"'",
             "context": {
                 "conv_uid": "uuid_conv_stream_1234"
             },
@@ -59,19 +59,27 @@ import logging
 from typing import Dict, List, Optional, Union

 from dbgpt._private.pydantic import BaseModel, Field
-from dbgpt.core import InMemoryStorage, LLMClient
+from dbgpt.core import (
+    ChatPromptTemplate,
+    HumanPromptTemplate,
+    InMemoryStorage,
+    MessagesPlaceholder,
+    ModelMessage,
+    ModelRequest,
+    ModelRequestContext,
+    SystemPromptTemplate,
+)
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
 from dbgpt.core.operator import (
-    BufferedConversationMapperOperator,
+    ChatComposerInput,
+    ChatHistoryPromptComposerOperator,
     LLMBranchOperator,
+)
+from dbgpt.model.operator import (
     LLMOperator,
-    PostConversationOperator,
-    PostStreamingConversationOperator,
-    PreConversationOperator,
-    RequestBuildOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator

 logger = logging.getLogger(__name__)
@@ -100,16 +108,15 @@ class TriggerReqBody(BaseModel):
     )


-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
+async def build_model_request(
+    messages: List[ModelMessage], req_body: TriggerReqBody
+) -> ModelRequest:
+    return ModelRequest.build_request(
+        model=req_body.model,
+        messages=messages,
+        context=req_body.context,
+        stream=req_body.stream,
+    )


 with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
@@ -120,56 +127,53 @@ with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
         request_body=TriggerReqBody,
         streaming_predict_func=lambda req: req.stream,
     )
-    # Transform request body to model request.
-    request_handle_task = RequestBuildOperator()
-    # Pre-process conversation, use InMemoryStorage to store conversation.
-    pre_conversation_task = PreConversationOperator(
-        storage=InMemoryStorage(), message_storage=InMemoryStorage()
+    prompt = ChatPromptTemplate(
+        messages=[
+            SystemPromptTemplate.from_template("You are a helpful chatbot."),
+            MessagesPlaceholder(variable_name="chat_history"),
+            HumanPromptTemplate.from_template("{user_input}"),
+        ]
     )
-    # Keep last k round conversation.
-    history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)
-
-    # Save conversation to storage.
-    post_conversation_task = PostConversationOperator()
-    # Save streaming conversation to storage.
-    post_streaming_conversation_task = PostStreamingConversationOperator()
+    composer_operator = ChatHistoryPromptComposerOperator(
+        prompt_template=prompt,
+        last_k_round=5,
+        storage=InMemoryStorage(),
+        message_storage=InMemoryStorage(),
+    )

-    # Use LLMOperator to generate response.
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    # Use BaseLLMOperator to generate response.
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )

-    (
-        trigger
-        >> request_handle_task
-        >> pre_conversation_task
-        >> history_conversation_task
-        >> branch_task
+    req_handle_task = MapOperator(
+        lambda req: ChatComposerInput(
+            context=ModelRequestContext(
+                conv_uid=req.context.conv_uid, stream=req.stream
+            ),
+            prompt_dict={"user_input": req.messages},
+            model_dict={
+                "model": req.model,
+                "context": req.context,
+                "stream": req.stream,
+            },
+        )
     )
+
+    trigger >> req_handle_task >> composer_operator >> branch_task
+
     # The branch of no streaming response.
-    (
-        branch_task
-        >> llm_task
-        >> post_conversation_task
-        >> model_parse_task
-        >> result_join_task
-    )
+    branch_task >> llm_task >> model_parse_task >> result_join_task
     # The branch of streaming response.
-    (
-        branch_task
-        >> streaming_llm_task
-        >> post_streaming_conversation_task
-        >> openai_format_stream_task
-        >> result_join_task
-    )
+    branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task

 if __name__ == "__main__":
     if multi_round_dag.leaf_nodes[0].dev_mode:
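A hypothetical two-round client for the multi_round trigger, mirroring the curl examples in the docstring: reusing the same conv_uid is what lets ChatHistoryPromptComposerOperator replay the first round as chat_history in the second. Server URL, model name, and message texts are illustrative assumptions:

    import requests

    DBGPT_SERVER = "http://127.0.0.1:5555"
    URL = f"{DBGPT_SERVER}/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions"

    for message in ["Hello, remember the number 42.", "What number did I ask you to remember?"]:
        resp = requests.post(
            URL,
            json={
                "model": "gpt-3.5-turbo",
                "context": {"conv_uid": "uuid_conv_1234"},  # same id both rounds
                "messages": message,
            },
            timeout=60,
        )
        print(resp.json())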
@@ -31,3 +31,11 @@ with DAG("simple_dag_example") as dag:
     trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
     map_node = RequestHandleOperator()
     trigger >> map_node
+
+if __name__ == "__main__":
+    if dag.leaf_nodes[0].dev_mode:
+        from dbgpt.core.awel import setup_dev_environment
+
+        setup_dev_environment([dag])
+    else:
+        pass
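Usage note: with this block, running the file directly (for example, python simple_dag_example.py) should start a local development runtime through setup_dev_environment([dag]), so the /examples/hello trigger can be exercised without a full DB-GPT deployment; when the file is instead loaded by a deployed server, dev_mode should be false and the block is a no-op. This reading follows from the added code itself rather than from any statement in the commit.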
@@ -8,9 +8,10 @@
     .. code-block:: shell

         DBGPT_SERVER="http://127.0.0.1:5555"
+        MODEL="gpt-3.5-turbo"
         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello"
         }'
@@ -19,7 +20,7 @@

         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
             "messages": "hello",
             "stream": true
         }'
@@ -29,7 +30,7 @@

         curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
         -H "Content-Type: application/json" -d '{
-            "model": "proxyllm",
+            "model": "'"$MODEL"'",
            "messages": "hello"
         }'
@@ -40,13 +41,13 @@ from typing import Any, Dict, List, Optional, Union
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core import LLMClient
 from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
-from dbgpt.core.operator import (
-    LLMBranchOperator,
+from dbgpt.core.operator import LLMBranchOperator, RequestBuilderOperator
+from dbgpt.model.operator import (
     LLMOperator,
-    RequestBuildOperator,
+    MixinLLMOperator,
+    OpenAIStreamingOutputOperator,
     StreamingLLMOperator,
 )
-from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator

 logger = logging.getLogger(__name__)
@@ -59,18 +60,6 @@ class TriggerReqBody(BaseModel):
     stream: Optional[bool] = Field(default=False, description="Whether return stream")


-class MyLLMOperator(MixinLLMOperator, LLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        LLMOperator.__init__(self, llm_client, **kwargs)
-
-
-class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
-    def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
-        super().__init__(llm_client)
-        StreamingLLMOperator.__init__(self, llm_client, **kwargs)
-
-
 class MyModelToolOperator(
     MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
 ):
@@ -97,14 +86,14 @@ with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
         request_body=TriggerReqBody,
         streaming_predict_func=lambda req: req.stream,
     )
-    request_handle_task = RequestBuildOperator()
-    llm_task = MyLLMOperator(task_name="llm_task")
-    streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
+    request_handle_task = RequestBuilderOperator()
+    llm_task = LLMOperator(task_name="llm_task")
+    streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
     branch_task = LLMBranchOperator(
         stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
     )
     model_parse_task = MapOperator(lambda out: out.to_dict())
-    openai_format_stream_task = OpenAIStreamingOperator()
+    openai_format_stream_task = OpenAIStreamingOutputOperator()
     result_join_task = JoinOperator(
         combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
     )