feat(core): Support multi round conversation operator (#986)

This commit is contained in:
Fangyin Cheng
2023-12-27 23:26:28 +08:00
committed by GitHub
parent 9aec636b02
commit b13d3f6d92
63 changed files with 2011 additions and 314 deletions

View File

@@ -114,6 +114,9 @@ class ModelRequestContext:
span_id: Optional[str] = None
"""The span id of the model inference."""
chat_mode: Optional[str] = None
"""The chat mode of the model inference."""
extra: Optional[Dict[str, Any]] = field(default_factory=dict)
"""The extra information of the model inference."""
@@ -195,7 +198,13 @@ class ModelRequest:
# Skip None fields
return {k: v for k, v in asdict(new_reqeust).items() if v}
def _get_messages(self) -> List[ModelMessage]:
def get_messages(self) -> List[ModelMessage]:
"""Get the messages.
If the messages are not already instances of ModelMessage, they will be converted to ModelMessage.
Returns:
List[ModelMessage]: The messages.
"""
return list(
map(
lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
@@ -209,7 +218,7 @@ class ModelRequest:
Returns:
Optional[ModelMessage]: The single user message.
"""
messages = self._get_messages()
messages = self.get_messages()
if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
raise ValueError("The messages is not a single user message")
return messages[0]

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
@@ -176,6 +176,22 @@ class ModelMessage(BaseModel):
def build_human_message(content: str) -> "ModelMessage":
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
@staticmethod
def get_printable_message(messages: List["ModelMessage"]) -> str:
"""Get the printable message"""
str_msg = ""
for message in messages:
curr_message = (
f"(Round {message.round_index}) {message.role}: {message.content} "
)
str_msg += curr_message.rstrip() + "\n"
return str_msg
_SingleRoundMessage = List[ModelMessage]
_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
def _message_to_dict(message: BaseMessage) -> Dict:
return message.to_dict()

View File

@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, List, Optional
from dbgpt.core import (
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
ModelOutput,
ModelRequest,
ModelRequestContext,
@@ -12,6 +13,7 @@ from dbgpt.core import (
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
from dbgpt.core.interface.message import _MultiRoundMessageMapper
class BaseConversationOperator(BaseOperator, ABC):
@@ -24,7 +26,7 @@ class BaseConversationOperator(BaseOperator, ABC):
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(**kwargs)
self._storage = storage
@@ -88,7 +90,7 @@ class PreConversationOperator(
self,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs
**kwargs,
):
super().__init__(storage=storage, message_storage=message_storage)
MapOperator.__init__(self, **kwargs)
@@ -109,7 +111,7 @@ class PreConversationOperator(
if not input_value.context.extra:
input_value.context.extra = {}
chat_mode = input_value.context.extra.get("chat_mode")
chat_mode = input_value.context.chat_mode
# Create a new storage conversation, this will load the conversation from storage, so we must do this async
storage_conv: StorageConversation = await self.blocking_func_to_async(
@@ -121,11 +123,8 @@ class PreConversationOperator(
conv_storage=self.storage,
message_storage=self.message_storage,
)
# The input message must be a single user message
single_human_message: ModelMessage = input_value.get_single_user_message()
storage_conv.start_new_round()
storage_conv.add_user_message(single_human_message.content)
input_messages = input_value.get_messages()
await self.save_to_storage(storage_conv, input_messages)
# Get all messages from current storage conversation, and overwrite the input value
messages: List[ModelMessage] = storage_conv.get_model_messages()
input_value.messages = messages
@@ -139,6 +138,42 @@ class PreConversationOperator(
)
return input_value
async def save_to_storage(
    self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
) -> None:
    """Persist the input messages to storage as a new conversation round.

    Args:
        storage_conv (StorageConversation): The storage conversation.
        input_messages (List[ModelMessage]): The input messages.
    """
    # Validate roles up front so nothing is written for bad input.
    self.check_messages(input_messages)
    storage_conv.start_new_round()
    for msg in input_messages:
        if msg.role != ModelMessageRoleType.HUMAN:
            storage_conv.add_system_message(msg.content)
        else:
            storage_conv.add_user_message(msg.content)
def check_messages(self, messages: List[ModelMessage]) -> None:
    """Validate that messages are non-empty and use supported roles.

    Args:
        messages (List[ModelMessage]): The messages.

    Raises:
        ValueError: If the messages is empty, or if any message has a
            role other than human or system.
    """
    if not messages:
        raise ValueError("Input messages is empty")
    supported_roles = (
        ModelMessageRoleType.HUMAN,
        ModelMessageRoleType.SYSTEM,
    )
    for message in messages:
        if message.role not in supported_roles:
            raise ValueError(f"Message role {message.role} is not supported")
async def after_dag_end(self):
"""The callback after DAG end"""
# Save the storage conversation to storage after the whole DAG finished
@@ -198,8 +233,9 @@ class PostStreamingConversationOperator(
class ConversationMapperOperator(
BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
):
def __init__(self, **kwargs):
def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
MapOperator.__init__(self, **kwargs)
self._message_mapper = message_mapper
async def map(self, input_value: ModelRequest) -> ModelRequest:
"""Map the input value to a ModelRequest.
@@ -211,12 +247,12 @@ class ConversationMapperOperator(
ModelRequest: The mapped ModelRequest.
"""
input_value = input_value.copy()
messages: List[ModelMessage] = await self.map_messages(input_value.messages)
messages: List[ModelMessage] = self.map_messages(input_value.messages)
# Overwrite the input value
input_value.messages = messages
return input_value
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
Args:
@@ -225,7 +261,73 @@ class ConversationMapperOperator(
Returns:
List[ModelMessage]: The mapped ModelMessage.
"""
return messages
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def map_multi_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[ModelMessage]:
"""Map multi round messages to a list of ModelMessage
By default, just merge all multi round messages to a list of ModelMessage according origin order.
And you can overwrite this method to implement your own logic.
Examples:
Merge multi round messages to a list of ModelMessage according origin order.
.. code-block:: python
import asyncio
from dbgpt.core.operator import ConversationMapperOperator
messages_by_round = [
[
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
],
[
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
],
[
ModelMessage(role="human", content="Funny!", round_index=3),
],
]
operator = ConversationMapperOperator()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
Map multi round messages to a list of ModelMessage just keep the last one round.
.. code-block:: python
class MyMapper(ConversationMapperOperator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def map_multi_round_messages(self, messages_by_round: List[List[ModelMessage]]) -> List[ModelMessage]:
return messages_by_round[-1]
operator = MyMapper()
messages = operator.map_multi_round_messages(messages_by_round)
assert messages == [
ModelMessage(role="human", content="Funny!", round_index=3),
]
Args:
"""
# Just merge and return
# e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
return sum(messages_by_round, [])
def _split_messages_by_round(
self, messages: List[ModelMessage]
@@ -236,7 +338,7 @@ class ConversationMapperOperator(
messages (List[ModelMessage]): The input messages.
Returns:
List[List[ModelMessage]]: The splitted messages.
List[List[ModelMessage]]: The split messages.
"""
messages_by_round: List[List[ModelMessage]] = []
last_round_index = 0
@@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
.. code-block:: python
import asyncio
from dbgpt.core import ModelMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
# No history
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
Transform with history messages
@@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(last_k_round=1)
messages = asyncio.run(operator.map_messages(messages))
# Just keep the last one round, so the first round messages will be removed
# Note: The round index 3 is not a complete round
assert messages == [
assert operator.map_messages(messages) == [
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
@@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
]
"""
def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
super().__init__(**kwargs)
def __init__(
self,
last_k_round: Optional[int] = 2,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
self._last_k_round = last_k_round
if message_mapper:
async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
"""Map the input messages to a list of ModelMessage.
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
# Apply keep k round messages first, then apply the custom message mapper
messages_by_round = self._keep_last_round_messages(messages_by_round)
return message_mapper(messages_by_round)
else:
def new_message_mapper(
messages_by_round: List[List[ModelMessage]],
) -> List[ModelMessage]:
messages_by_round = self._keep_last_round_messages(messages_by_round)
return sum(messages_by_round, [])
super().__init__(new_message_mapper, **kwargs)
def _keep_last_round_messages(
self, messages_by_round: List[List[ModelMessage]]
) -> List[List[ModelMessage]]:
"""Keep the last k round messages.
Args:
messages (List[ModelMessage]): The input messages.
messages_by_round (List[List[ModelMessage]]): The messages by round.
Returns:
List[ModelMessage]: The mapped ModelMessage.
List[List[ModelMessage]]: The latest round messages.
"""
messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
messages
)
# Get the last k round messages
index = self._last_k_round + 1
messages_by_round = messages_by_round[-index:]
messages: List[ModelMessage] = sum(messages_by_round, [])
return messages
return messages_by_round[-index:]

View File

@@ -169,9 +169,7 @@ class StoragePromptTemplate(StorageItem):
def to_prompt_template(self) -> PromptTemplate:
"""Convert the storage prompt template to a prompt template."""
input_variables = (
None
if not self.input_variables
else self.input_variables.strip().split(",")
[] if not self.input_variables else self.input_variables.strip().split(",")
)
return PromptTemplate(
input_variables=input_variables,
@@ -458,6 +456,33 @@ class PromptManager:
)
self.storage.save(storage_prompt_template)
def query_or_save(
    self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
) -> StoragePromptTemplate:
    """Load a prompt template from storage, saving it first if absent.

    Args:
        prompt_template (PromptTemplate): The prompt template to save.
        prompt_name (str): The name of the prompt template.
        kwargs (Dict): Other params to build the storage prompt template.
            More details in :meth:`~StoragePromptTemplate.from_prompt_template`.

    Returns:
        StoragePromptTemplate: The storage prompt template.
    """
    # Build a storage item only to derive the identifier used for lookup.
    candidate = StoragePromptTemplate.from_prompt_template(
        prompt_template, prompt_name, **kwargs
    )
    existing = self.storage.load(candidate.identifier, StoragePromptTemplate)
    if existing:
        return existing
    # Not found: persist via save(), then re-load so the caller receives
    # the item exactly as it now exists in storage.
    self.save(prompt_template, prompt_name, **kwargs)
    return self.storage.load(candidate.identifier, StoragePromptTemplate)
def list(self, **kwargs) -> List[StoragePromptTemplate]:
"""List prompt templates from storage.