feat(core): Support multi round conversation operator (#986)

Fangyin Cheng authored 2023-12-27 23:26:28 +08:00, committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions


@@ -0,0 +1,438 @@
"""AWEL: Data analyst assistant.
DB-GPT will automatically load and execute the current file after startup.
Examples:
.. code-block:: shell
# Run this file from your terminal in dev mode.
# First terminal
export OPENAI_API_KEY=xxx
export OPENAI_API_BASE=https://api.openai.com/v1
python examples/awel/simple_chat_history_example.py
Call the code fix command (non-streaming response):
.. code-block:: shell
# Open a new terminal
# Second terminal
DBGPT_SERVER="http://127.0.0.1:5555"
MODEL="gpt-3.5-turbo"
# First round
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
-H "Content-Type: application/json" -d '{
"command": "dbgpt_awel_data_analyst_code_fix",
"model": "gpt-3.5-turbo",
"stream": false,
"context": {
"conv_uid": "uuid_conv_copilot_1234",
"chat_mode": "chat_with_code"
},
"messages": "SELECT * FRM orders WHERE order_amount > 500;"
}'
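Multi-round conversation: send a follow-up request with the same "conv_uid" to
continue the conversation (the follow-up message below is a hypothetical example).
.. code-block:: shell
# Second round, same conversation
curl -X POST $DBGPT_SERVER/api/v1/awel/trigger/examples/data_analyst/copilot \
-H "Content-Type: application/json" -d '{
"command": "dbgpt_awel_data_analyst_code_fix",
"model": "gpt-3.5-turbo",
"stream": false,
"context": {
"conv_uid": "uuid_conv_copilot_1234",
"chat_mode": "chat_with_code"
},
"messages": "Explain the fix you made in the previous round."
}'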
"""
import logging
from functools import cache
from typing import Any, Dict, List, Optional
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core import (
InMemoryStorage,
LLMClient,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
PromptManager,
PromptTemplate,
StorageConversation,
StorageInterface,
)
from dbgpt.core.awel import DAG, HttpTrigger, JoinOperator, MapOperator
from dbgpt.core.operator import (
BufferedConversationMapperOperator,
LLMBranchOperator,
LLMOperator,
PostConversationOperator,
PostStreamingConversationOperator,
PreConversationOperator,
RequestBuildOperator,
StreamingLLMOperator,
)
from dbgpt.model import MixinLLMOperator, OpenAIStreamingOperator
from dbgpt.util.utils import colored
logger = logging.getLogger(__name__)
CODE_FIX = "dbgpt_awel_data_analyst_code_fix"
CODE_PERF = "dbgpt_awel_data_analyst_code_perf"
CODE_EXPLAIN = "dbgpt_awel_data_analyst_code_explain"
CODE_COMMENT = "dbgpt_awel_data_analyst_code_comment"
CODE_TRANSLATE = "dbgpt_awel_data_analyst_code_translate"
CODE_FIX_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
这里有一段 {language} 代码。请按照最佳实践检查代码,找出并修复所有错误。请给出修复后的代码,并且提供对您所做的每一行更正的逐行解释,请使用和用户相同的语言进行回答。"""
CODE_FIX_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
here is a snippet of {language} code. Please review the code following best practices to identify and fix all errors.
Provide the corrected code with a line-by-line explanation of every fix you've made, and respond in the same language as the user."""
CODE_PERF_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,这里有一段 {language} 代码。
请你按照最佳实践来优化这段代码。请在代码中加入注释点明所做的更改,并解释每项优化的原因,以便提高代码的维护性和性能,请使用和用户相同的语言进行回答。"""
CODE_PERF_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you are provided with a snippet of {language} code. Please optimize the code according to best practices.
Include comments highlighting the changes made and explaining the reason for each optimization for better maintainability and performance,
and respond in the same language as the user."""
CODE_EXPLAIN_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,
现在给你的是一份 {language} 代码。请你逐行解释代码的含义,请使用和用户相同的语言进行回答。"""
CODE_EXPLAIN_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you are provided with a snippet of {language} code. Please explain the meaning of the code line by line,
and respond in the same language as the user."""
CODE_COMMENT_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在给你的是一份 {language} 代码。
请你为每一行代码添加注释,解释每个部分的作用,请使用和用户相同的语言进行回答。"""
CODE_COMMENT_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
below is a snippet of code written in {language}.
Please provide line-by-line comments explaining what each section of the code does, and respond in the same language as the user."""
CODE_TRANSLATE_TEMPLATE_ZH = """作为一名经验丰富的数据仓库开发者和数据分析师,现在手头有一份用{source_language}语言编写的代码片段。
请你将这段代码准确无误地翻译成{target_language}语言,确保语法和功能在翻译后的代码中得到正确体现,请使用和用户相同的语言进行回答。"""
CODE_TRANSLATE_TEMPLATE_EN = """As an experienced data warehouse developer and data analyst,
you're presented with a snippet of code written in {source_language}.
Please translate this code into {target_language}, ensuring that the syntax and functionality are accurately reflected in the translated code,
and respond in the same language as the user."""
class ReqContext(BaseModel):
user_name: Optional[str] = Field(
None, description="The user name of the model request."
)
sys_code: Optional[str] = Field(
None, description="The system code of the model request."
)
conv_uid: Optional[str] = Field(
None, description="The conversation uid of the model request."
)
chat_mode: Optional[str] = Field(
"chat_with_code", description="The chat mode of the model request."
)
class TriggerReqBody(BaseModel):
messages: str = Field(..., description="User input messages")
    # The default must be a registered prompt name; the bare string "fix" would
    # never match a saved prompt and would raise ValueError in CopilotOperator.
    command: Optional[str] = Field(default=CODE_FIX, description="Command name")
model: Optional[str] = Field(default="gpt-3.5-turbo", description="Model name")
stream: Optional[bool] = Field(default=False, description="Whether return stream")
language: Optional[str] = Field(default="hive", description="Language")
    target_language: Optional[str] = Field(
        default="hive", description="Target language, used by the translate command"
    )
context: Optional[ReqContext] = Field(
default=None, description="The context of the model request."
)
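# functools.cache memoizes per PromptManager instance, so the templates below
# are registered at most once for each manager.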
@cache
def load_or_save_prompt_template(pm: PromptManager):
ext_params = {
"chat_scene": "chat_with_code",
"sub_chat_scene": "data_analyst",
"prompt_type": "common",
}
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_FIX_TEMPLATE_ZH,
),
prompt_name=CODE_FIX,
prompt_language="zh",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_FIX_TEMPLATE_EN,
),
prompt_name=CODE_FIX,
prompt_language="en",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_PERF_TEMPLATE_ZH,
),
prompt_name=CODE_PERF,
prompt_language="zh",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_PERF_TEMPLATE_EN,
),
prompt_name=CODE_PERF,
prompt_language="en",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_ZH,
),
prompt_name=CODE_EXPLAIN,
prompt_language="zh",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_EXPLAIN_TEMPLATE_EN,
),
prompt_name=CODE_EXPLAIN,
prompt_language="en",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_ZH,
),
prompt_name=CODE_COMMENT,
prompt_language="zh",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["language"],
template=CODE_COMMENT_TEMPLATE_EN,
),
prompt_name=CODE_COMMENT,
prompt_language="en",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_ZH,
),
prompt_name=CODE_TRANSLATE,
prompt_language="zh",
**ext_params,
)
pm.query_or_save(
PromptTemplate(
input_variables=["source_language", "target_language"],
template=CODE_TRANSLATE_TEMPLATE_EN,
),
prompt_name=CODE_TRANSLATE,
prompt_language="en",
**ext_params,
)
class CopilotOperator(MapOperator[TriggerReqBody, Dict[str, Any]]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._default_prompt_manager = PromptManager()
async def map(self, input_value: TriggerReqBody) -> Dict[str, Any]:
from dbgpt.serve.prompt.serve import SERVE_APP_NAME as PROMPT_SERVE_APP_NAME
from dbgpt.serve.prompt.serve import Serve as PromptServe
prompt_serve = self.system_app.get_component(
PROMPT_SERVE_APP_NAME, PromptServe, default_component=None
)
if prompt_serve:
pm = prompt_serve.prompt_manager
else:
pm = self._default_prompt_manager
load_or_save_prompt_template(pm)
user_language = self.system_app.config.get_current_lang(default="en")
prompt_list = pm.prefer_query(
input_value.command, prefer_prompt_language=user_language
)
if not prompt_list:
error_msg = f"Prompt not found for command {input_value.command}, user_language: {user_language}"
logger.error(error_msg)
raise ValueError(error_msg)
prompt = prompt_list[0].to_prompt_template()
if input_value.command == CODE_TRANSLATE:
format_params = {
"source_language": input_value.language,
"target_language": input_value.target_language,
}
else:
format_params = {"language": input_value.language}
system_message = prompt.format(**format_params)
messages = [
ModelMessage(role=ModelMessageRoleType.SYSTEM, content=system_message),
ModelMessage(role=ModelMessageRoleType.HUMAN, content=input_value.messages),
]
context = input_value.context.dict() if input_value.context else {}
return {
"messages": messages,
"stream": input_value.stream,
"model": input_value.model,
"context": context,
}
class MyConversationOperator(PreConversationOperator):
def __init__(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
):
super().__init__(storage, message_storage, **kwargs)
    def _get_conversation_serve(self):
from dbgpt.serve.conversation.serve import (
SERVE_APP_NAME as CONVERSATION_SERVE_APP_NAME,
)
from dbgpt.serve.conversation.serve import Serve as ConversationServe
conversation_serve: ConversationServe = self.system_app.get_component(
CONVERSATION_SERVE_APP_NAME, ConversationServe, default_component=None
)
return conversation_serve
@property
def storage(self):
if self._storage:
return self._storage
        conversation_serve = self._get_conversation_serve()
if conversation_serve:
return conversation_serve.conv_storage
else:
logger.info("Conversation storage not found, use InMemoryStorage default")
self._storage = InMemoryStorage()
return self._storage
@property
def message_storage(self):
if self._message_storage:
return self._message_storage
        conversation_serve = self._get_conversation_serve()
if conversation_serve:
return conversation_serve.message_storage
else:
logger.info("Message storage not found, use InMemoryStorage default")
self._message_storage = InMemoryStorage()
return self._message_storage
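# MixinLLMOperator supplies a default LLMClient (resolved from the runtime
# system app) when no llm_client is passed in, so these operators work both
# inside DB-GPT and in dev mode.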
class MyLLMOperator(MixinLLMOperator, LLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
LLMOperator.__init__(self, llm_client, **kwargs)
class MyStreamingLLMOperator(MixinLLMOperator, StreamingLLMOperator):
def __init__(self, llm_client: Optional[LLMClient] = None, **kwargs):
super().__init__(llm_client)
StreamingLLMOperator.__init__(self, llm_client, **kwargs)
def history_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
"""Mapper for history conversation.
If there are multi system messages, just keep the first system message.
"""
has_system_message = False
mapper_messages = []
for messages in messages_by_round:
for message in messages:
if message.role == ModelMessageRoleType.SYSTEM:
if has_system_message:
continue
else:
mapper_messages.append(message)
has_system_message = True
else:
mapper_messages.append(message)
print("history_message_mapper start:" + "=" * 70)
print(colored(ModelMessage.get_printable_message(mapper_messages), "green"))
print("history_message_mapper end:" + "=" * 72)
return mapper_messages
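# A minimal, illustrative sketch of the mapper's behavior (message contents
# are hypothetical): given two rounds that both start with the same system
# message, only the first system message is kept.
#
#   rounds = [
#       [ModelMessage(role=ModelMessageRoleType.SYSTEM, content="sys"),
#        ModelMessage(role=ModelMessageRoleType.HUMAN, content="round 1")],
#       [ModelMessage(role=ModelMessageRoleType.SYSTEM, content="sys"),
#        ModelMessage(role=ModelMessageRoleType.HUMAN, content="round 2")],
#   ]
#   history_message_mapper(rounds)
#   # -> [system("sys"), human("round 1"), human("round 2")]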
with DAG("dbgpt_awel_data_analyst_assistant") as dag:
trigger = HttpTrigger(
"/examples/data_analyst/copilot",
request_body=TriggerReqBody,
methods="POST",
streaming_predict_func=lambda x: x.stream,
)
copilot_task = CopilotOperator()
request_handle_task = RequestBuildOperator()
# Pre-process conversation
pre_conversation_task = MyConversationOperator()
    # Keep only the last k rounds of the conversation history.
history_conversation_task = BufferedConversationMapperOperator(
last_k_round=5, message_mapper=history_message_mapper
)
# Save conversation to storage.
post_conversation_task = PostConversationOperator()
# Save streaming conversation to storage.
post_streaming_conversation_task = PostStreamingConversationOperator()
    # Use LLM operators to generate the response.
llm_task = MyLLMOperator(task_name="llm_task")
streaming_llm_task = MyStreamingLLMOperator(task_name="streaming_llm_task")
branch_task = LLMBranchOperator(
stream_task_name="streaming_llm_task", no_stream_task_name="llm_task"
)
model_parse_task = MapOperator(lambda out: out.to_dict())
openai_format_stream_task = OpenAIStreamingOperator()
result_join_task = JoinOperator(
combine_function=lambda not_stream_out, stream_out: not_stream_out or stream_out
)
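    # Shared trunk: parse the request, load and record conversation state,
    # then branch on the `stream` flag.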
(
trigger
>> copilot_task
>> request_handle_task
>> pre_conversation_task
>> history_conversation_task
>> branch_task
)
# The branch of no streaming response.
(
branch_task
>> llm_task
>> post_conversation_task
>> model_parse_task
>> result_join_task
)
# The branch of streaming response.
(
branch_task
>> streaming_llm_task
>> post_streaming_conversation_task
>> openai_format_stream_task
>> result_join_task
)
if __name__ == "__main__":
    if dag.leaf_nodes[0].dev_mode:
        # Development mode: run the DAG locally for debugging.
        from dbgpt.core.awel import setup_dev_environment

        setup_dev_environment([dag])
    else:
        # Production mode: DB-GPT loads and executes the current file
        # automatically after startup.
        pass