feat(awel): New MessageConverter and more AWEL operators (#1039)

Author: Fangyin Cheng
Date: 2024-01-08 09:40:05 +08:00 (committed by GitHub)
Parent: 765fb181f6
Commit: e8861bd8fa
48 changed files with 2333 additions and 719 deletions


@@ -26,7 +26,7 @@
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
-H "Content-Type: application/json" -d '{
"command": "dbgpt_awel_data_analyst_code_fix",
"model": "gpt-3.5-turbo",
"model": "'"$MODEL"'",
"stream": false,
"context": {
"conv_uid": "uuid_conv_copilot_1234",
@@ -37,43 +37,55 @@
"""
import logging
import os
from functools import cache
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
InMemoryStorage,
LLMClient,
MessageStorageItem,
ChatPromptTemplate,
HumanPromptTemplate,
MessagesPlaceholder,
ModelMessage,
ModelMessageRoleType,
ModelRequest,
ModelRequestContext,
PromptManager,
PromptTemplate,
StorageConversation,
StorageInterface,
SystemPromptTemplate,
)
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import (
BufferedConversationMapperOperator,
HistoryDynamicPromptBuilderOperator,
LLMBranchOperator,
RequestBuilderOperator,
)
from dbgpt.model.operator import (
LLMOperator,
PostConversationOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
OpenAIStreamingOutputOperator,
StreamingLLMOperator,
)
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
from dbgpt.util.utils import colored
from dbgpt.serve.conversation.operator import ServePreChatHistoryLoadOperator
logger = logging.getLogger(__name__)
PROMPT_LANG_ZH = "zh"
PROMPT_LANG_EN = "en"
CODE_DEFAULT = "dbgpt_awel_data_analyst_code_default"
CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"
CODE_DEFAULT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师。
你可以根据最佳实践来优化代码, 也可以对代码进行修复, 解释, 添加注释, 以及将代码翻译成其他语言。"""
CODE_DEFAULT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you can optimize the code according to best practices, fix, explain, or add comments to the code,
and also translate the code into other languages.
"""
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
@@ -126,7 +138,9 @@ class ReqContext(BaseModel):
class TriggerReqBody(BaseModel):
messages: str = Field(..., description="User input messages")
command: Optional[str] = Field(default="fix", description="Command name")
command: Optional[str] = Field(
default=None, description="Command name; None for plain chat"
)
model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
stream: Optional[bool] = Field(default=False, description="Whether to return a streaming response")
language: Optional[str] = Field(default="hive", description="Language")
@@ -140,109 +154,89 @@ class TriggerReqBody(BaseModel):
@cache
def load_or_save_prompt_template(pm: PromptManager):
ext_params = {
zh_ext_params = {
"chat_scene": "chat_with_code",
"sub_chat_scene": "data_analyst",
"prompt_type": "common",
"prompt_language": PROMPT_LANG_ZH,
}
en_ext_params = {
"chat_scene": "chat_with_code",
"sub_chat_scene": "data_analyst",
"prompt_type": "common",
"prompt_language": PROMPT_LANG_EN,
}
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_FIX_TEMPLATE_ZH,
),
PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_ZH),
prompt_name=CODE_DEFAULT,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate.from_template(CODE_DEFAULT_TEMPLATE_EN),
prompt_name=CODE_DEFAULT,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate.from_template(CODE_FIX_TEMPLATE_ZH),
prompt_name=CODE_FIX,
prompt_language="zh",
**ext_params,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_FIX_TEMPLATE_EN,
),
PromptTemplate.from_template(CODE_FIX_TEMPLATE_EN),
prompt_name=CODE_FIX,
prompt_language="en",
**ext_params,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_PERF_TEMPLATE_ZH,
),
PromptTemplate.from_template(CODE_PERF_TEMPLATE_ZH),
prompt_name=CODE_PERF,
prompt_language="zh",
**ext_params,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_PERF_TEMPLATE_EN,
),
PromptTemplate.from_template(CODE_PERF_TEMPLATE_EN),
prompt_name=CODE_PERF,
prompt_language="en",
**ext_params,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_ZH,
),
PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_ZH),
prompt_name=CODE_EXPLAIN,
prompt_language="zh",
**ext_params,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_EN,
),
PromptTemplate.from_template(CODE_EXPLAIN_TEMPLATE_EN),
prompt_name=CODE_EXPLAIN,
prompt_language="en",
**ext_params,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_ZH,
),
PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_ZH),
prompt_name=CODE_COMMENT,
prompt_language="zh",
**ext_params,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_EN,
),
PromptTemplate.from_template(CODE_COMMENT_TEMPLATE_EN),
prompt_name=CODE_COMMENT,
prompt_language="en",
**ext_params,
**en_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_ZH,
),
PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_ZH),
prompt_name=CODE_TRANSLATE,
prompt_language="zh",
**ext_params,
**zh_ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_EN,
),
PromptTemplate.from_template(CODE_TRANSLATE_TEMPLATE_EN),
prompt_name=CODE_TRANSLATE,
prompt_language="en",
**ext_params,
**en_ext_params,
)
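# Illustrative usage sketch (not part of this commit): assuming an in-process
# PromptManager backed by its default storage, the templates registered above can
# be fetched back by command name, preferring a given language:
#
#     pm = PromptManager()
#     load_or_save_prompt_template(pm)
#     prompts = pm.prefer_query(CODE_FIX, prefer_prompt_language=PROMPT_LANG_ZH)
#     fix_template = prompts[0].to_prompt_template()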
class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
class PromptTemplateBuilderOperator(MapOperator[TriggerReqBody, ChatPromptTemplate]):
"""Build prompt template for chat with code."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._default_prompt_manager = PromptManager()
async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
async def map(self, input_value: TriggerReqBody) -> ChatPromptTemplate:
from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
from dbgpt.serve.prompt.serve import Serve as PromptServe
@@ -256,7 +250,24 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
load_or_save_prompt_template(pm)
user_language = self.system_app.config.get_current_lang(default="en")
if not input_value.command:
# No command: plain chat, using the default system prompt instead of a command-specific one.
default_prompt_list = pm.prefer_query(
CODE_DEFAULT, prefer_prompt_language=user_language
)
default_prompt_template = (
default_prompt_list[0].to_prompt_template().template
)
prompt = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(default_prompt_template),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{user_input}"),
]
)
return prompt
# Query prompt template from prompt manager by command name
prompt_list = pm.prefer_query(
input_value.command, prefer_prompt_language=user_language
)
@@ -264,109 +275,38 @@ class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
logger.error(error_msg)
raise ValueError(error_msg)
prompt = prompt_list[0].to_prompt_template()
if input_value.command == CODE_TRANSLATE:
format_params = {
"source_language": input_value.language,
"target_language": input_value.target_language,
}
else:
format_params = {"language": input_value.language}
prompt_template = prompt_list[0].to_prompt_template()
system_message = prompt.format(**format_params)
messages = [
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message),
ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages),
]
context = input_value.context.dict() if input_value.context else {}
return {
"messages": messages,
"stream": input_value.stream,
"model": input_value.model,
"context": context,
}
class MyConversationOperator(PreConversationOperator):
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(storage, message_storage, **kwargs)
def _get_conversion_serve(self):
from dbgpt.serve.conversation.serve import (
SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
return ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template(prompt_template.template),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{user_input}"),
]
)
from dbgpt.serve.conversation.serve import Serve as ConversationServe
conversation_serve: ConversationServe = self.system_app.get_component(
CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
)
return conversation_serve
@property
def storage(self):
if self._storage:
return self._storage
conversation_serve = self._get_conversion_serve()
if conversation_serve:
return conversation_serve.conv_storage
else:
logger.info("Conversation storage not found, use InMemoryStorage default")
self._storage = InMemoryStorage()
return self._storage
@property
def message_storage(self):
if self._message_storage:
return self._message_storage
conversation_serve = self._get_conversion_serve()
if conversation_serve:
return conversation_serve.message_storage
else:
logger.info("Message storage not found, use InMemoryStorage default")
self._message_storage = InMemoryStorage()
return self._message_storage
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
def parse_prompt_args(req: TriggerReqBody) -> Dict[str, Any]:
prompt_args = {"user_input": req.messages}
if not req.command:
return prompt_args
if req.command == CODE_TRANSLATE:
prompt_args["source_language"] = req.language
prompt_args["target_language"] = req.target_language
else:
prompt_args["language"] = req.language
return prompt_args
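# For illustration (hypothetical request values): a translate request such as
# TriggerReqBody(messages="SELECT 1", command=CODE_TRANSLATE, language="hive",
# target_language="spark") produces
# {"user_input": "SELECT 1", "source_language": "hive", "target_language": "spark"},
# while any other command only fills "user_input" and "language".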
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
def history_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
"""Mapper for history conversation.
If there are multiple system messages, keep only the first one.
"""
has_system_message = False
mapper_messages = []
for messages in messages_by_round:
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
if has_system_message:
continue
else:
mapper_messages.append(message)
has_system_message = True
else:
mapper_messages.append(message)
print("history_message_mapper start:" + "=" * 70)
print(colored(ModelMessage.get_printable_message(mapper_messages), "green"))
print("history_message_mapper end:" + "=" * 72)
return mapper_messages
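# For illustration (hypothetical rounds): if round 1 is [SYSTEM s1, HUMAN h1, AI a1]
# and round 2 is [SYSTEM s1, HUMAN h2], the mapper returns
# [SYSTEM s1, HUMAN h1, AI a1, HUMAN h2]; later system messages are dropped.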
async def build_model_request(
messages: List[ModelMessage], req_body: TriggerReqBody
) -> ModelRequest:
return ModelRequest.build_request(
model=req_body.model,
messages=messages,
context=req_body.context,
stream=req_body.stream,
)
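# Note: this coroutine is used below via JoinOperator(build_model_request); the
# formatted history messages are expected to arrive from history_prompt_build_task
# and the raw TriggerReqBody from the trigger (see the edges under the DAG).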
with DAG("dbgpt_awel_data_analyst_assistant") as dag:
@@ -377,57 +317,59 @@ with DAG("dbgpt_awel_data_analyst_assistant") as dag:
streaming_predict_func=lambda x: x.stream,
)
copilot_task = CopilotOperator()
request_handle_task = RequestBuildOperator()
prompt_template_load_task = PromptTemplateBuilderOperator()
request_handle_task = RequestBuilderOperator()
# Pre-process conversation
pre_conversation_task = MyConversationOperator()
# Keep last k round conversation.
history_conversation_task = BufferedConversationMapperOperator(
last_k_round=5, message_mapper=history_message_mapper
# Load and store chat history
chat_history_load_task = ServePreChatHistoryLoadOperator()
last_k_round = int(os.getenv("DBGPT_AWEL_DATA_ANALYST_LAST_K_ROUND", 5))
# History transform task: keep the last k rounds of messages
history_transform_task = BufferedConversationMapperOperator(
last_k_round=last_k_round
)
history_prompt_build_task = HistoryDynamicPromptBuilderOperator(
history_key="chat_history"
)
# Save conversation to storage.
post_conversation_task = PostConversationOperator()
# Save streaming conversation to storage.
post_streaming_conversation_task = PostStreamingConversationOperator()
model_request_build_task = JoinOperator(build_model_request)
# Use LLMOperator to generate response.
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
# Use BaseLLMOperator to generate response.
llm_task = LLMOperator(task_name="llm_task")
streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
openai_format_stream_task = OpenAIStreamingOutputOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)
trigger >> prompt_template_load_task >> history_prompt_build_task
(
trigger
>> copilot_task
>> request_handle_task
>> pre_conversation_task
>> history_conversation_task
>> branch_task
>> MapOperator(
lambda req: ModelRequestContext(
conv_uid=req.context.conv_uid,
stream=req.stream,
chat_mode=req.context.chat_mode,
)
)
>> chat_history_load_task
>> history_transform_task
>> history_prompt_build_task
)
trigger >> MapOperator(parse_prompt_args) >> history_prompt_build_task
history_prompt_build_task >> model_request_build_task
trigger >> model_request_build_task
model_request_build_task >> branch_task
# The branch of no streaming response.
(
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
(branch_task >> llm_task >> model_parse_task >> result_join_task)
# The branch of streaming response.
(
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
(branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task)
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:


@@ -12,7 +12,7 @@
# First round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"model": "'"$MODEL"'",
"context": {
"conv_uid": "uuid_conv_1234"
},
@@ -22,7 +22,7 @@
# Second round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"model": "'"$MODEL"'",
"context": {
"conv_uid": "uuid_conv_1234"
},
@@ -34,7 +34,7 @@
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"model": "'"$MODEL"'",
"context": {
"conv_uid": "uuid_conv_stream_1234"
},
@@ -45,7 +45,7 @@
# Second round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_history/multi_round/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "gpt-3.5-turbo",
"model": "'"$MODEL"'",
"context": {
"conv_uid": "uuid_conv_stream_1234"
},
@@ -59,19 +59,27 @@ import logging
from typing import Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import InMemoryStorage, LLMClient
from dbgpt.core import (
ChatPromptTemplate,
HumanPromptTemplate,
InMemoryStorage,
MessagesPlaceholder,
ModelMessage,
ModelRequest,
ModelRequestContext,
SystemPromptTemplate,
)
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import (
BufferedConversationMapperOperator,
ChatComposerInput,
ChatHistoryPromptComposerOperator,
LLMBranchOperator,
)
from dbgpt.model.operator import (
LLMOperator,
PostConversationOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
OpenAIStreamingOutputOperator,
StreamingLLMOperator,
)
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
logger = logging.getLogger(__name__)
@@ -100,16 +108,15 @@ class TriggerReqBody(BaseModel):
)
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
async def build_model_request(
messages: List[ModelMessage], req_body: TriggerReqBody
) -> ModelRequest:
return ModelRequest.build_request(
model=req_body.model,
messages=messages,
context=req_body.context,
stream=req_body.stream,
)
with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
@@ -120,56 +127,53 @@ with DAG("dbgpt_awel_simple_chat_history") as multi_round_dag:
request_body=TriggerReqBody,
streaming_predict_func=lambda req: req.stream,
)
# Transform request body to model request.
request_handle_task = RequestBuildOperator()
# Pre-process conversation, use InMemoryStorage to store conversation.
pre_conversation_task = PreConversationOperator(
storage=InMemoryStorage(), message_storage=InMemoryStorage()
prompt = ChatPromptTemplate(
messages=[
SystemPromptTemplate.from_template("You are a helpful chatbot."),
MessagesPlaceholder(variable_name="chat_history"),
HumanPromptTemplate.from_template("{user_input}"),
]
)
# Keep last k round conversation.
history_conversation_task = BufferedConversationMapperOperator(last_k_round=5)
# Save conversation to storage.
post_conversation_task = PostConversationOperator()
# Save streaming conversation to storage.
post_streaming_conversation_task = PostStreamingConversationOperator()
composer_operator = ChatHistoryPromptComposerOperator(
prompt_template=prompt,
last_k_round=5,
storage=InMemoryStorage(),
message_storage=InMemoryStorage(),
)
# Use LLMOperator to generate response.
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
# Use BaseLLMOperator to generate response.
llm_task = LLMOperator(task_name="llm_task")
streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
openai_format_stream_task = OpenAIStreamingOutputOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)
(
trigger
>> request_handle_task
>> pre_conversation_task
>> history_conversation_task
>> branch_task
req_handle_task = MapOperator(
lambda req: ChatComposerInput(
context=ModelRequestContext(
conv_uid=req.context.conv_uid, stream=req.stream
),
prompt_dict={"user_input": req.messages},
model_dict={
"model": req.model,
"context": req.context,
"stream": req.stream,
},
)
)
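# For illustration (hypothetical request values): a body like
# {"model": "gpt-3.5-turbo", "messages": "hello", "stream": false,
# "context": {"conv_uid": "uuid_conv_1234"}} is mapped to a ChatComposerInput with
# prompt_dict={"user_input": "hello"} and a ModelRequestContext carrying the
# conv_uid, which composer_operator then turns into the final model request
# with the chat history filled in.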
trigger >> req_handle_task >> composer_operator >> branch_task
# The branch of no streaming response.
(
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
branch_task >> llm_task >> model_parse_task >> result_join_task
# The branch of streaming response.
(
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
branch_task >> streaming_llm_task >> openai_format_stream_task >> result_join_task
if __name__ == "__main__":
if multi_round_dag.leaf_nodes[0].dev_mode:


@@ -31,3 +31,11 @@ with DAG("simple_dag_example") as dag:
trigger = HttpTrigger("/examples/hello", request_body=TriggerReqBody)
map_node = RequestHandleOperator()
trigger >> map_node
if __name__ == "__main__":
if dag.leaf_nodes[0].dev_mode:
from dbgpt.core.awel import setup_dev_environment
setup_dev_environment([dag])
else:
pass


@@ -8,9 +8,10 @@
.. code-block:: shell
DBGPT_SERVER="http://127.0.0.1:5555"
MODEL="gpt-3.5-turbo"
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"model": "'"$MODEL"'",
"messages": "hello"
}'
@@ -19,7 +20,7 @@
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/chat/completions \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"model": "'"$MODEL"'",
"messages": "hello",
"stream": true
}'
@@ -29,7 +30,7 @@
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/simple_client/count_token \
-H "Content-Type: application/json" -d '{
"model": "proxyllm",
"model": "'"$MODEL"'",
"messages": "hello"
}'
@@ -40,13 +41,13 @@ from typing import Any, Dict, List, Optional, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import LLMClient
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import (
LLMBranchOperator,
from dbgpt.core.operator import LLMBranchOperator, RequestBuilderOperator
from dbgpt.model.operator import (
LLMOperator,
RequestBuildOperator,
MixinLLMOperator,
OpenAIStreamingOutputOperator,
StreamingLLMOperator,
)
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
logger = logging.getLogger(__name__)
@@ -59,18 +60,6 @@ class TriggerReqBody(BaseModel):
stream: Optional[bool] = Field(default=False, description="Whether to return a streaming response")
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
class MyModelToolOperator(
MixinLLMOperator, MapOperator[TriggerReqBody, Dict[str, Any]]
):
@@ -97,14 +86,14 @@ with DAG("dbgpt_awel_simple_llm_client_generate") as client_generate_dag:
request_body=TriggerReqBody,
streaming_predict_func=lambda req: req.stream,
)
request_handle_task = RequestBuildOperator()
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
request_handle_task = RequestBuilderOperator()
llm_task = LLMOperator(task_name="llm_task")
streaming_llm_task = StreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
openai_format_stream_task = OpenAIStreamingOutputOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)