Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-09-09 04:49:26 +00:00
feat(core): Support multi round conversation operator (#986)
@@ -114,6 +114,9 @@ class ModelRequestContext:
     span_id: Optional[str] = None
     """The span id of the model inference."""
 
+    chat_mode: Optional[str] = None
+    """The chat mode of the model inference."""
+
     extra: Optional[Dict[str, Any]] = field(default_factory=dict)
     """The extra information of the model inference."""
 
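This hunk gives chat_mode a first-class slot on the request context instead of burying it in the free-form extra dict. A minimal usage sketch, assuming the dataclass's remaining fields are likewise optional; the chat-mode string and extra payload are illustrative, not values mandated by DB-GPT:

    from dbgpt.core import ModelRequestContext

    # chat_mode is now a first-class field; "chat_normal" is an illustrative value.
    context = ModelRequestContext(chat_mode="chat_normal", extra={"user_name": "alice"})
    assert context.chat_mode == "chat_normal"
    assert context.span_id is None  # span_id stays optional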
@@ -195,7 +198,13 @@ class ModelRequest:
         # Skip None fields
         return {k: v for k, v in asdict(new_reqeust).items() if v}
 
-    def _get_messages(self) -> List[ModelMessage]:
+    def get_messages(self) -> List[ModelMessage]:
+        """Get the messages.
+
+        If the messages are not a list of ModelMessage, they will be converted to a list of ModelMessage.
+
+        Returns:
+            List[ModelMessage]: The messages.
+        """
         return list(
             map(
                 lambda m: m if isinstance(m, ModelMessage) else ModelMessage(**m),
@@ -209,7 +218,7 @@ class ModelRequest:
         Returns:
             Optional[ModelMessage]: The single user message.
         """
-        messages = self._get_messages()
+        messages = self.get_messages()
        if len(messages) != 1 and messages[0].role != ModelMessageRoleType.HUMAN:
             raise ValueError("The messages is not a single user message")
         return messages[0]
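Taken together, the renamed get_messages() normalizes raw dict messages into ModelMessage objects, and get_single_user_message() enforces the single-human-message invariant. A small sketch, assuming model and messages are the only constructor keywords a ModelRequest needs here (the model name is illustrative):

    from dbgpt.core import ModelMessage, ModelRequest

    request = ModelRequest(
        model="vicuna-13b-v1.5",  # illustrative model name
        messages=[{"role": "human", "content": "Hello"}],  # raw dicts are allowed
    )
    messages = request.get_messages()
    assert isinstance(messages[0], ModelMessage)  # dict was converted
    # Exactly one human message, so this does not raise
    single = request.get_single_user_message()
    assert single.content == "Hello"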
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from dbgpt._private.pydantic import BaseModel, Field
 from dbgpt.core.awel import MapOperator
@@ -176,6 +176,22 @@ class ModelMessage(BaseModel):
     def build_human_message(content: str) -> "ModelMessage":
         return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
 
+    @staticmethod
+    def get_printable_message(messages: List["ModelMessage"]) -> str:
+        """Get the printable message"""
+        str_msg = ""
+        for message in messages:
+            curr_message = (
+                f"(Round {message.round_index}) {message.role}: {message.content} "
+            )
+            str_msg += curr_message.rstrip() + "\n"
+
+        return str_msg
+
+
+_SingleRoundMessage = List[ModelMessage]
+_MultiRoundMessageMapper = Callable[[List[_SingleRoundMessage]], List[ModelMessage]]
+
 
 def _message_to_dict(message: BaseMessage) -> Dict:
     return message.to_dict()
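The new _MultiRoundMessageMapper alias is the extension point used by the operators below: any callable that takes messages grouped by round and returns a flat list satisfies it. A sketch of a conforming mapper, paired with the new get_printable_message helper (the function name and data are mine, not from the commit):

    from typing import List

    from dbgpt.core import ModelMessage

    def keep_last_round(
        messages_by_round: List[List[ModelMessage]],
    ) -> List[ModelMessage]:
        # Conforms to _MultiRoundMessageMapper: rounds in, flat message list out.
        return messages_by_round[-1] if messages_by_round else []

    rounds = [
        [ModelMessage(role="human", content="Hi", round_index=1)],
        [ModelMessage(role="human", content="Bye", round_index=2)],
    ]
    print(ModelMessage.get_printable_message(keep_last_round(rounds)))
    # (Round 2) human: Bye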
@@ -5,6 +5,7 @@ from typing import Any, AsyncIterator, List, Optional
 from dbgpt.core import (
     MessageStorageItem,
     ModelMessage,
+    ModelMessageRoleType,
     ModelOutput,
     ModelRequest,
     ModelRequestContext,
@@ -12,6 +13,7 @@ from dbgpt.core import (
     StorageInterface,
 )
 from dbgpt.core.awel import BaseOperator, MapOperator, TransformStreamAbsOperator
+from dbgpt.core.interface.message import _MultiRoundMessageMapper
 
 
 class BaseConversationOperator(BaseOperator, ABC):
@@ -24,7 +26,7 @@ class BaseConversationOperator(BaseOperator, ABC):
         self,
         storage: Optional[StorageInterface[StorageConversation, Any]] = None,
         message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(**kwargs)
         self._storage = storage
@@ -88,7 +90,7 @@ class PreConversationOperator(
         self,
         storage: Optional[StorageInterface[StorageConversation, Any]] = None,
         message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
-        **kwargs
+        **kwargs,
     ):
         super().__init__(storage=storage, message_storage=message_storage)
         MapOperator.__init__(self, **kwargs)
@@ -109,7 +111,7 @@ class PreConversationOperator(
         if not input_value.context.extra:
             input_value.context.extra = {}
 
-        chat_mode = input_value.context.extra.get("chat_mode")
+        chat_mode = input_value.context.chat_mode
 
         # Create a new storage conversation, this will load the conversation from storage, so we must do this async
         storage_conv: StorageConversation = await self.blocking_func_to_async(
@@ -121,11 +123,8 @@ class PreConversationOperator(
             conv_storage=self.storage,
             message_storage=self.message_storage,
         )
-        # The input message must be a single user message
-        single_human_message: ModelMessage = input_value.get_single_user_message()
-        storage_conv.start_new_round()
-        storage_conv.add_user_message(single_human_message.content)
-
+        input_messages = input_value.get_messages()
+        await self.save_to_storage(storage_conv, input_messages)
         # Get all messages from current storage conversation, and overwrite the input value
         messages: List[ModelMessage] = storage_conv.get_model_messages()
         input_value.messages = messages
@@ -139,6 +138,42 @@ class PreConversationOperator(
         )
         return input_value
 
+    async def save_to_storage(
+        self, storage_conv: StorageConversation, input_messages: List[ModelMessage]
+    ) -> None:
+        """Save the messages to storage.
+
+        Args:
+            storage_conv (StorageConversation): The storage conversation.
+            input_messages (List[ModelMessage]): The input messages.
+        """
+        # Check the messages first
+        self.check_messages(input_messages)
+        storage_conv.start_new_round()
+        for message in input_messages:
+            if message.role == ModelMessageRoleType.HUMAN:
+                storage_conv.add_user_message(message.content)
+            else:
+                storage_conv.add_system_message(message.content)
+
+    def check_messages(self, messages: List[ModelMessage]) -> None:
+        """Check the messages.
+
+        Args:
+            messages (List[ModelMessage]): The messages.
+
+        Raises:
+            ValueError: If the messages are empty, or a message role is not supported.
+        """
+        if not messages:
+            raise ValueError("Input messages is empty")
+        for message in messages:
+            if message.role not in [
+                ModelMessageRoleType.HUMAN,
+                ModelMessageRoleType.SYSTEM,
+            ]:
+                raise ValueError(f"Message role {message.role} is not supported")
+
     async def after_dag_end(self):
         """The callback after DAG end"""
         # Save the storage conversation to storage after the whole DAG finished
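save_to_storage validates before it writes: only human and system roles may appear in the input, since AI replies are appended later by the post-operators. A standalone restatement of the check (my helper name; the operator itself would normally be wired into an AWEL DAG with storage backends):

    from dbgpt.core import ModelMessage, ModelMessageRoleType

    def check_roles(messages):
        # The same validation check_messages performs, in free-function form.
        if not messages:
            raise ValueError("Input messages is empty")
        for m in messages:
            if m.role not in [ModelMessageRoleType.HUMAN, ModelMessageRoleType.SYSTEM]:
                raise ValueError(f"Message role {m.role} is not supported")

    check_roles([ModelMessage(role="human", content="Hi", round_index=1)])  # passes
    try:
        check_roles([ModelMessage(role="ai", content="Hello!", round_index=1)])
    except ValueError as e:
        print(e)  # Message role ai is not supported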
@@ -198,8 +233,9 @@ class PostStreamingConversationOperator(
 class ConversationMapperOperator(
     BaseConversationOperator, MapOperator[ModelRequest, ModelRequest]
 ):
-    def __init__(self, **kwargs):
+    def __init__(self, message_mapper: _MultiRoundMessageMapper = None, **kwargs):
         MapOperator.__init__(self, **kwargs)
+        self._message_mapper = message_mapper
 
     async def map(self, input_value: ModelRequest) -> ModelRequest:
         """Map the input value to a ModelRequest.
@@ -211,12 +247,12 @@ class ConversationMapperOperator(
             ModelRequest: The mapped ModelRequest.
         """
         input_value = input_value.copy()
-        messages: List[ModelMessage] = await self.map_messages(input_value.messages)
+        messages: List[ModelMessage] = self.map_messages(input_value.messages)
         # Overwrite the input value
         input_value.messages = messages
         return input_value
 
-    async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
+    def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
         """Map the input messages to a list of ModelMessage.
 
         Args:
@@ -225,7 +261,73 @@ class ConversationMapperOperator(
         Returns:
             List[ModelMessage]: The mapped ModelMessage.
         """
-        return messages
+        messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
+            messages
+        )
+        message_mapper = self._message_mapper or self.map_multi_round_messages
+        return message_mapper(messages_by_round)
+
+    def map_multi_round_messages(
+        self, messages_by_round: List[List[ModelMessage]]
+    ) -> List[ModelMessage]:
+        """Map multi round messages to a list of ModelMessage.
+
+        By default, all multi round messages are merged into a single list of
+        ModelMessage in their original order. You can overwrite this method to
+        implement your own logic.
+
+        Examples:
+
+            Merge multi round messages into a list of ModelMessage in their
+            original order.
+
+            .. code-block:: python
+
+                from dbgpt.core import ModelMessage
+                from dbgpt.core.operator import ConversationMapperOperator
+
+                messages_by_round = [
+                    [
+                        ModelMessage(role="human", content="Hi", round_index=1),
+                        ModelMessage(role="ai", content="Hello!", round_index=1),
+                    ],
+                    [
+                        ModelMessage(role="system", content="Error 404", round_index=2),
+                        ModelMessage(role="human", content="What's the error?", round_index=2),
+                        ModelMessage(role="ai", content="Just a joke.", round_index=2),
+                    ],
+                    [
+                        ModelMessage(role="human", content="Funny!", round_index=3),
+                    ],
+                ]
+                operator = ConversationMapperOperator()
+                messages = operator.map_multi_round_messages(messages_by_round)
+                assert messages == [
+                    ModelMessage(role="human", content="Hi", round_index=1),
+                    ModelMessage(role="ai", content="Hello!", round_index=1),
+                    ModelMessage(role="system", content="Error 404", round_index=2),
+                    ModelMessage(role="human", content="What's the error?", round_index=2),
+                    ModelMessage(role="ai", content="Just a joke.", round_index=2),
+                    ModelMessage(role="human", content="Funny!", round_index=3),
+                ]
+
+            Map multi round messages to a list of ModelMessage, keeping only the
+            last round.
+
+            .. code-block:: python
+
+                class MyMapper(ConversationMapperOperator):
+                    def __init__(self, **kwargs):
+                        super().__init__(**kwargs)
+
+                    def map_multi_round_messages(
+                        self, messages_by_round: List[List[ModelMessage]]
+                    ) -> List[ModelMessage]:
+                        return messages_by_round[-1]
+
+                operator = MyMapper()
+                messages = operator.map_multi_round_messages(messages_by_round)
+                assert messages == [
+                    ModelMessage(role="human", content="Funny!", round_index=3),
+                ]
+
+        Args:
+            messages_by_round (List[List[ModelMessage]]): The messages grouped by round.
+        """
+        # Just merge and return
+        # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
+        return sum(messages_by_round, [])
 
     def _split_messages_by_round(
         self, messages: List[ModelMessage]
@@ -236,7 +338,7 @@ class ConversationMapperOperator(
             messages (List[ModelMessage]): The input messages.
 
         Returns:
-            List[List[ModelMessage]]: The splitted messages.
+            List[List[ModelMessage]]: The split messages.
         """
         messages_by_round: List[List[ModelMessage]] = []
         last_round_index = 0
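The body of _split_messages_by_round is truncated in this hunk; the idea is to open a new group whenever a message's round_index advances past the last one seen. A standalone sketch of that grouping, not necessarily line-for-line what the commit contains, assuming round_index starts at 1 as in the examples above:

    from typing import List

    from dbgpt.core import ModelMessage

    def split_messages_by_round(
        messages: List[ModelMessage],
    ) -> List[List[ModelMessage]]:
        # Start a new round group whenever round_index increases.
        messages_by_round: List[List[ModelMessage]] = []
        last_round_index = 0
        for message in messages:
            if message.round_index > last_round_index:
                last_round_index = message.round_index
                messages_by_round.append([])
            messages_by_round[-1].append(message)
        return messages_by_round

    msgs = [
        ModelMessage(role="human", content="Hi", round_index=1),
        ModelMessage(role="ai", content="Hello!", round_index=1),
        ModelMessage(role="human", content="Funny!", round_index=2),
    ]
    assert [len(r) for r in split_messages_by_round(msgs)] == [2, 1]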
@@ -263,15 +365,13 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
 
         .. code-block:: python
 
-            import asyncio
             from dbgpt.core import ModelMessage
             from dbgpt.core.operator import BufferedConversationMapperOperator
 
             # No history
             messages = [ModelMessage(role="human", content="Hello", round_index=1)]
             operator = BufferedConversationMapperOperator(last_k_round=1)
-            messages = asyncio.run(operator.map_messages(messages))
-            assert messages == [ModelMessage(role="human", content="Hello", round_index=1)]
+            assert operator.map_messages(messages) == [ModelMessage(role="human", content="Hello", round_index=1)]
 
         Transform with history messages
@@ -287,10 +387,9 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
                 ModelMessage(role="human", content="Funny!", round_index=3),
             ]
             operator = BufferedConversationMapperOperator(last_k_round=1)
-            messages = asyncio.run(operator.map_messages(messages))
             # Keep only the last round, so the first round's messages are removed
             # Note: round index 3 is not a complete round
-            assert messages == [
+            assert operator.map_messages(messages) == [
                 ModelMessage(role="system", content="Error 404", round_index=2),
                 ModelMessage(role="human", content="What's the error?", round_index=2),
                 ModelMessage(role="ai", content="Just a joke.", round_index=2),
@@ -298,24 +397,42 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
         ]
     """
 
-    def __init__(self, last_k_round: Optional[int] = 2, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(
+        self,
+        last_k_round: Optional[int] = 2,
+        message_mapper: _MultiRoundMessageMapper = None,
+        **kwargs,
+    ):
         self._last_k_round = last_k_round
+        if message_mapper:
+
+            def new_message_mapper(
+                messages_by_round: List[List[ModelMessage]],
+            ) -> List[ModelMessage]:
+                # Keep the last k rounds first, then apply the custom message mapper
+                messages_by_round = self._keep_last_round_messages(messages_by_round)
+                return message_mapper(messages_by_round)
+
+        else:
+
+            def new_message_mapper(
+                messages_by_round: List[List[ModelMessage]],
+            ) -> List[ModelMessage]:
+                messages_by_round = self._keep_last_round_messages(messages_by_round)
+                return sum(messages_by_round, [])
+
+        super().__init__(new_message_mapper, **kwargs)
 
-    async def map_messages(self, messages: List[ModelMessage]) -> List[ModelMessage]:
-        """Map the input messages to a list of ModelMessage.
+    def _keep_last_round_messages(
+        self, messages_by_round: List[List[ModelMessage]]
+    ) -> List[List[ModelMessage]]:
+        """Keep the last k round messages.
 
         Args:
-            messages (List[ModelMessage]): The input messages.
+            messages_by_round (List[List[ModelMessage]]): The messages by round.
 
         Returns:
-            List[ModelMessage]: The mapped ModelMessage.
+            List[List[ModelMessage]]: The latest round messages.
         """
-        messages_by_round: List[List[ModelMessage]] = self._split_messages_by_round(
-            messages
-        )
         # Get the last k round messages
         index = self._last_k_round + 1
-        messages_by_round = messages_by_round[-index:]
-        messages: List[ModelMessage] = sum(messages_by_round, [])
-        return messages
+        return messages_by_round[-index:]
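The slicing in _keep_last_round_messages keeps last_k_round complete rounds plus the current round, hence index = last_k_round + 1. A minimal sketch with plain lists standing in for the per-round List[ModelMessage] groups:

    # Stand-ins for per-round message groups.
    rounds = [["round 1"], ["round 2"], ["round 3 (current, maybe incomplete)"]]
    last_k_round = 1
    index = last_k_round + 1  # k previous rounds plus the current round
    assert rounds[-index:] == [["round 2"], ["round 3 (current, maybe incomplete)"]]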
@@ -169,9 +169,7 @@ class StoragePromptTemplate(StorageItem):
     def to_prompt_template(self) -> PromptTemplate:
         """Convert the storage prompt template to a prompt template."""
         input_variables = (
-            None
-            if not self.input_variables
-            else self.input_variables.strip().split(",")
+            [] if not self.input_variables else self.input_variables.strip().split(",")
         )
         return PromptTemplate(
             input_variables=input_variables,
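The switch from None to [] matters for empty input: an empty list is a safe "no variables" value wherever downstream code iterates over input_variables, while None would need a guard. Both branches of the rewritten expression, with plain strings standing in for the stored column:

    # Empty stored value -> empty list, not None.
    assert ([] if not "" else "".strip().split(",")) == []
    # Non-empty stored value -> split on commas as before.
    assert ([] if not "a,b" else "a,b".strip().split(",")) == ["a", "b"]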
@@ -458,6 +456,33 @@ class PromptManager:
         )
         self.storage.save(storage_prompt_template)
 
+    def query_or_save(
+        self, prompt_template: PromptTemplate, prompt_name: str, **kwargs
+    ) -> StoragePromptTemplate:
+        """Query a prompt template from storage; if not found, save it.
+
+        Args:
+            prompt_template (PromptTemplate): The prompt template to save.
+            prompt_name (str): The name of the prompt template.
+            kwargs (Dict): Other params to build the storage prompt template.
+                More details in :meth:`~StoragePromptTemplate.from_prompt_template`.
+
+        Returns:
+            StoragePromptTemplate: The storage prompt template.
+        """
+        storage_prompt_template = StoragePromptTemplate.from_prompt_template(
+            prompt_template, prompt_name, **kwargs
+        )
+        exist_prompt_template = self.storage.load(
+            storage_prompt_template.identifier, StoragePromptTemplate
+        )
+        if exist_prompt_template:
+            return exist_prompt_template
+        self.save(prompt_template, prompt_name, **kwargs)
+        return self.storage.load(
+            storage_prompt_template.identifier, StoragePromptTemplate
+        )
+
     def list(self, **kwargs) -> List[StoragePromptTemplate]:
         """List prompt templates from storage.
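query_or_save makes prompt registration idempotent: the first call persists the template, later calls with the same identifier return the stored copy. A hypothetical usage sketch; the PromptTemplate constructor arguments and the storage backing prompt_manager are assumptions, not shown in this diff:

    from dbgpt.core import PromptTemplate

    # prompt_manager: a PromptManager backed by some StorageInterface (not shown).
    template = PromptTemplate(
        template="Tell me about {topic}.",
        input_variables=["topic"],
    )
    saved = prompt_manager.query_or_save(template, prompt_name="topic_intro")
    # A second call loads the stored copy instead of saving a duplicate.
    again = prompt_manager.query_or_save(template, prompt_name="topic_intro")
    assert again.identifier == saved.identifier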