feat(core): APP use new SDK component (#1050)
@@ -20,8 +20,9 @@ class PromptTemplateRegistry:
         self,
         prompt_template,
         language: str = "en",
-        is_default=False,
+        is_default: bool = False,
         model_names: List[str] = None,
         scene_name: str = None,
     ) -> None:
         """Register prompt template with scene name, language
        registry dict format:
@@ -37,7 +38,8 @@ class PromptTemplateRegistry:
            }
        }
        """
-        scene_name = prompt_template.template_scene
+        if not scene_name:
+            scene_name = prompt_template.template_scene
+        if not scene_name:
+            raise ValueError("Prompt template scene name cannot be empty")
         if not model_names:
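The register() hunk adds a guard so a template can no longer be registered without a scene name: an explicit scene_name argument wins, the template's own template_scene is the fallback, and an empty result raises. A minimal self-contained sketch of that resolution order, assuming the hunk reads as reconstructed above (FakeTemplate is a hypothetical stand-in, not a dbgpt type):

```python
from typing import Optional


class FakeTemplate:
    """Hypothetical stand-in for a prompt template with a template_scene field."""

    def __init__(self, template_scene: Optional[str] = None):
        self.template_scene = template_scene


def resolve_scene_name(prompt_template, scene_name: Optional[str] = None) -> str:
    # Mirrors the new register() logic: explicit argument first,
    # then the template's own scene, otherwise raise.
    if not scene_name:
        scene_name = prompt_template.template_scene
    if not scene_name:
        raise ValueError("Prompt template scene name cannot be empty")
    return scene_name


assert resolve_scene_name(FakeTemplate("chat_normal")) == "chat_normal"
assert resolve_scene_name(FakeTemplate(), scene_name="chat_db") == "chat_db"
```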
@@ -45,6 +45,18 @@ class BaseMessage(BaseModel, ABC):
             "round_index": self.round_index,
         }

+    @staticmethod
+    def messages_to_string(messages: List["BaseMessage"]) -> str:
+        """Convert messages to str
+
+        Args:
+            messages (List[BaseMessage]): The messages
+
+        Returns:
+            str: The str messages
+        """
+        return _messages_to_str(messages)
+

 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
@@ -251,6 +263,41 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
     return [_message_to_dict(m) for m in messages]


+def _messages_to_str(
+    messages: List[BaseMessage],
+    human_prefix: str = "Human",
+    ai_prefix: str = "AI",
+    system_prefix: str = "System",
+) -> str:
+    """Convert messages to str
+
+    Args:
+        messages (List[BaseMessage]): The messages
+        human_prefix (str): The human prefix
+        ai_prefix (str): The ai prefix
+        system_prefix (str): The system prefix
+
+    Returns:
+        str: The str messages
+    """
+    str_messages = []
+    for message in messages:
+        role = None
+        if isinstance(message, HumanMessage):
+            role = human_prefix
+        elif isinstance(message, AIMessage):
+            role = ai_prefix
+        elif isinstance(message, SystemMessage):
+            role = system_prefix
+        elif isinstance(message, ViewMessage):
+            pass
+        else:
+            raise ValueError(f"Got unsupported message type: {message}")
+        if role:
+            str_messages.append(f"{role}: {message.content}")
+    return "\n".join(str_messages)
+
+
 def _message_from_dict(message: Dict) -> BaseMessage:
     _type = message["type"]
     if _type == "human":
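The new _messages_to_str helper renders a conversation as "Role: content" lines and silently skips view messages. A self-contained sketch of the resulting transcript format (the Msg dataclass is a stand-in, not a dbgpt type):

```python
from dataclasses import dataclass


@dataclass
class Msg:
    type: str
    content: str


def messages_to_str(messages, human_prefix="Human", ai_prefix="AI", system_prefix="System"):
    prefixes = {"human": human_prefix, "ai": ai_prefix, "system": system_prefix}
    lines = []
    for m in messages:
        if m.type == "view":
            continue  # view messages are display-only and contribute nothing
        lines.append(f"{prefixes[m.type]}: {m.content}")
    return "\n".join(lines)


transcript = messages_to_str([Msg("human", "Hi"), Msg("ai", "Hello!"), Msg("view", "Hello!")])
assert transcript == "Human: Hi\nAI: Hello!"
```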
@@ -382,6 +429,9 @@ class OnceConversation:
         self._message_index += 1
         message.index = index
         message.round_index = self.chat_order
+        message.additional_kwargs["param_type"] = self.param_type
+        message.additional_kwargs["param_value"] = self.param_value
+        message.additional_kwargs["model_name"] = self.model_name
         self.messages.append(message)

     def start_new_round(self) -> None:
@@ -504,9 +554,12 @@ class OnceConversation:
         self.messages = conversation.messages
         self.start_date = conversation.start_date
         self.chat_order = conversation.chat_order
-        self.model_name = conversation.model_name
-        self.param_type = conversation.param_type
-        self.param_value = conversation.param_value
+        if not self.model_name and conversation.model_name:
+            self.model_name = conversation.model_name
+        if not self.param_type and conversation.param_type:
+            self.param_type = conversation.param_type
+        if not self.param_value and conversation.param_value:
+            self.param_value = conversation.param_value
         self.cost = conversation.cost
         self.tokens = conversation.tokens
         self.user_name = conversation.user_name
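from_conversation now copies model_name, param_type, and param_value only when the local values are unset. The guard in plain form (the model names are illustrative):

```python
def merge_if_empty(current, incoming):
    # New from_conversation semantics: keep the already-set local value,
    # otherwise take the loaded conversation's value.
    return current if current else incoming


assert merge_if_empty(None, "vicuna-13b-v1.5") == "vicuna-13b-v1.5"
assert merge_if_empty("chatgpt_proxyllm", "vicuna-13b-v1.5") == "chatgpt_proxyllm"
```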
@@ -801,6 +854,7 @@ class StorageConversation(OnceConversation, StorageItem):
         save_message_independent: Optional[bool] = True,
         conv_storage: StorageInterface = None,
         message_storage: StorageInterface = None,
+        load_message: bool = True,
         **kwargs,
     ):
         super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
@@ -811,6 +865,8 @@ class StorageConversation(OnceConversation, StorageItem):
         self._has_stored_message_index = (
             len(kwargs["messages"]) - 1 if "messages" in kwargs else -1
         )
+        # Whether to load the message from the storage
+        self._load_message = load_message
         self.save_message_independent = save_message_independent
         self._id = ConversationIdentifier(conv_uid)
         if conv_storage is None:
@@ -853,7 +909,9 @@ class StorageConversation(OnceConversation, StorageItem):
         ]
         messages_to_save = message_list[self._has_stored_message_index + 1 :]
         self._has_stored_message_index = len(message_list) - 1
-        self.message_storage.save_list(messages_to_save)
+        if self.save_message_independent:
+            # Save messages independently
+            self.message_storage.save_list(messages_to_save)
         # Save conversation
         self.conv_storage.save_or_update(self)
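save_to_storage persists only the tail of the message list that has not been stored yet, then advances the watermark. The slice arithmetic in isolation:

```python
# Worked example of the incremental-save arithmetic: only messages after
# _has_stored_message_index are written on each save_to_storage call.
message_list = ["m0", "m1", "m2", "m3"]
has_stored_message_index = 1  # m0 and m1 were saved on a previous call
messages_to_save = message_list[has_stored_message_index + 1 :]
assert messages_to_save == ["m2", "m3"]
has_stored_message_index = len(message_list) - 1  # everything is now stored
```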
@@ -876,23 +934,71 @@ class StorageConversation(OnceConversation, StorageItem):
             return
         message_ids = conversation._message_ids or []

-        # Load messages
-        message_list = message_storage.load_list(
-            [
-                MessageIdentifier.from_str_identifier(message_id)
-                for message_id in message_ids
-            ],
-            MessageStorageItem,
-        )
-        messages = [message.to_message() for message in message_list]
-        conversation.messages = messages
+        if self._load_message:
+            # Load messages
+            message_list = message_storage.load_list(
+                [
+                    MessageIdentifier.from_str_identifier(message_id)
+                    for message_id in message_ids
+                ],
+                MessageStorageItem,
+            )
+            messages = [message.to_message() for message in message_list]
+        else:
+            messages = []
+        real_messages = messages or conversation.messages
+        conversation.messages = real_messages
         # This index is used to save the message to the storage(Has not been saved)
         # The new message append to the messages, so the index is len(messages)
-        conversation._message_index = len(messages)
+        conversation._message_index = len(real_messages)
+        conversation.chat_order = (
+            max(m.round_index for m in real_messages) if real_messages else 0
+        )
+        self._append_additional_kwargs(conversation, real_messages)
         self._message_ids = message_ids
-        self._has_stored_message_index = len(messages) - 1
+        self._has_stored_message_index = len(real_messages) - 1
         self.save_message_independent = conversation.save_message_independent
         self.from_conversation(conversation)

+    def _append_additional_kwargs(
+        self, conversation: StorageConversation, messages: List[BaseMessage]
+    ) -> None:
+        """Parse the additional kwargs and append to the conversation
+
+        Args:
+            conversation (StorageConversation): The conversation
+            messages (List[BaseMessage]): The messages
+        """
+        param_type = None
+        param_value = None
+        for message in messages[::-1]:
+            if message.additional_kwargs:
+                param_type = message.additional_kwargs.get("param_type")
+                param_value = message.additional_kwargs.get("param_value")
+                break
+        if not conversation.param_type:
+            conversation.param_type = param_type
+        if not conversation.param_value:
+            conversation.param_value = param_value
+
+    def delete(self) -> None:
+        """Delete all the messages and conversation from the storage"""
+        # Delete messages first
+        message_list = self._get_message_items()
+        message_ids = [message.identifier for message in message_list]
+        self.message_storage.delete_list(message_ids)
+        # Delete conversation
+        self.conv_storage.delete(self.identifier)
+        # Overwrite the current conversation with empty conversation
+        self.from_conversation(
+            StorageConversation(
+                self.conv_uid,
+                save_message_independent=self.save_message_independent,
+                conv_storage=self.conv_storage,
+                message_storage=self.message_storage,
+            )
+        )
+

 def _conversation_to_dict(once: OnceConversation) -> Dict:
     start_str: str = ""
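_append_additional_kwargs walks the messages newest-first and takes param_type/param_value from the first message carrying additional kwargs. The same scan over plain dicts (illustrative values):

```python
additional_kwargs_per_message = [
    {"param_type": "db", "param_value": "orders"},
    {},
    {"param_type": "db", "param_value": "users"},
]
param_type = param_value = None
for kwargs in additional_kwargs_per_message[::-1]:
    if kwargs:
        # The newest message that carries additional kwargs wins
        param_type = kwargs.get("param_type")
        param_value = kwargs.get("param_value")
        break
assert (param_type, param_value) == ("db", "users")
```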
@@ -937,3 +1043,61 @@ def _conversation_from_dict(once: dict) -> OnceConversation:
     print(once.get("messages"))
     conversation.messages = _messages_from_dict(once.get("messages", []))
     return conversation
+
+
+def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessage]]:
+    """Split the messages by round index.
+
+    Args:
+        messages (List[BaseMessage]): The messages.
+
+    Returns:
+        List[List[BaseMessage]]: The messages split by round.
+    """
+    messages_by_round: List[List[BaseMessage]] = []
+    last_round_index = 0
+    for message in messages:
+        if not message.round_index:
+            # Round index must be bigger than 0
+            raise ValueError("Message round_index is not set")
+        if message.round_index > last_round_index:
+            last_round_index = message.round_index
+            messages_by_round.append([])
+        messages_by_round[-1].append(message)
+    return messages_by_round
+
+
+def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
+    """Append the view message to the messages
+    Just for show in DB-GPT-Web.
+    If a view message already exists, do nothing.
+
+    Args:
+        messages (List[BaseMessage]): The messages
+
+    Returns:
+        List[BaseMessage]: The messages with view message
+    """
+    messages_by_round = _split_messages_by_round(messages)
+    for current_round in messages_by_round:
+        ai_message = None
+        view_message = None
+        for message in current_round:
+            if message.type == "ai":
+                ai_message = message
+            elif message.type == "view":
+                view_message = message
+        if view_message:
+            # Already have view message, do nothing
+            continue
+        if ai_message:
+            view_message = ViewMessage(
+                content=ai_message.content,
+                index=ai_message.index,
+                round_index=ai_message.round_index,
+                additional_kwargs=ai_message.additional_kwargs.copy()
+                if ai_message.additional_kwargs
+                else {},
+            )
+            current_round.append(view_message)
+    return sum(messages_by_round, [])
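_split_messages_by_round starts a new bucket whenever round_index increases and rejects messages with no round index. An equivalent sketch over (round_index, content) tuples instead of BaseMessage objects:

```python
def split_by_round(messages):
    by_round, last_round_index = [], 0
    for round_index, content in messages:
        if not round_index:
            raise ValueError("Message round_index is not set")
        if round_index > last_round_index:
            last_round_index = round_index
            by_round.append([])
        by_round[-1].append((round_index, content))
    return by_round


msgs = [(1, "Hi"), (1, "Hello!"), (2, "How are you?"), (2, "Fine, thanks!")]
assert split_by_round(msgs) == [
    [(1, "Hi"), (1, "Hello!")],
    [(2, "How are you?"), (2, "Fine, thanks!")],
]
```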
@@ -45,7 +45,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         self,
         prompt_template: ChatPromptTemplate,
         history_key: str = "chat_history",
-        last_k_round: int = 2,
+        keep_start_rounds: Optional[int] = None,
+        keep_end_rounds: Optional[int] = None,
         storage: Optional[StorageInterface[StorageConversation, Any]] = None,
         message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
         **kwargs,
@@ -53,7 +54,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         super().__init__(**kwargs)
         self._prompt_template = prompt_template
         self._history_key = history_key
-        self._last_k_round = last_k_round
+        self._keep_start_rounds = keep_start_rounds
+        self._keep_end_rounds = keep_end_rounds
         self._storage = storage
         self._message_storage = message_storage
         self._sub_compose_dag = self._build_composer_dag()
@@ -74,7 +76,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequest]):
         )
         # History transform task, here we keep last 5 round messages
         history_transform_task = BufferedConversationMapperOperator(
-            last_k_round=self._last_k_round
+            keep_start_rounds=self._keep_start_rounds,
+            keep_end_rounds=self._keep_end_rounds,
         )
         history_prompt_build_task = HistoryPromptBuilderOperator(
             prompt=self._prompt_template, history_key=self._history_key
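The composer drops its last_k_round parameter in favor of keep_start_rounds/keep_end_rounds. In slicing terms, the removed helper kept the current round plus the last k history rounds (index = last_k_round + 1, per the deleted code further down), which keep_end_rounds can reproduce directly; a sketch, assuming rounds are already grouped:

```python
rounds = [["r1"], ["r2"], ["r3"], ["r4"]]
# Old behavior: last_k_round history rounds plus the current round
last_k_round = 2
assert rounds[-(last_k_round + 1):] == [["r2"], ["r3"], ["r4"]]
# New behavior: keep_end_rounds keeps exactly that many trailing rounds
keep_end_rounds = 3
assert rounds[-keep_end_rounds:] == [["r2"], ["r3"], ["r4"]]
```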
@@ -1,8 +1,9 @@
 import uuid
 from abc import ABC
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 from dbgpt.core import (
+    LLMClient,
     MessageStorageItem,
     ModelMessage,
     ModelMessageRoleType,
@@ -11,7 +12,12 @@ from dbgpt.core import (
     StorageInterface,
 )
 from dbgpt.core.awel import BaseOperator, MapOperator
-from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
+from dbgpt.core.interface.message import (
+    BaseMessage,
+    _messages_to_str,
+    _MultiRoundMessageMapper,
+    _split_messages_by_round,
+)


 class BaseConversationOperator(BaseOperator, ABC):
@@ -31,7 +37,6 @@ class BaseConversationOperator(BaseOperator, ABC):
         **kwargs,
     ):
         self._check_storage = check_storage
         super().__init__(**kwargs)
         self._storage = storage
         self._message_storage = message_storage
@@ -167,12 +172,10 @@ class ConversationMapperOperator(
         self._message_mapper = message_mapper

     async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
-        return self.map_messages(input_value)
+        return await self.map_messages(input_value)

-    def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
-        messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
-            messages
-        )
+    async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
+        messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
         message_mapper = self._message_mapper or self.map_multi_round_messages
         return message_mapper(messages_by_round)
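map_messages is now a coroutine and ultimately hands the per-round groups to a _MultiRoundMessageMapper, which returns a flat list. A sketch of such a mapper over stand-in data:

```python
def drop_first_round(messages_by_round):
    # Receives messages grouped by round, returns a flat list
    return sum(messages_by_round[1:], [])


assert drop_first_round([["a", "b"], ["c"], ["d"]]) == ["c", "d"]
```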
@@ -233,93 +236,66 @@ class ConversationMapperOperator(
         Args:
         """
         # Just merge and return
         # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
-        return sum(messages_by_round, [])
-
-    def _split_messages_by_round(
-        self, messages: List[BaseMessage]
-    ) -> List[List[BaseMessage]]:
-        """Split the messages by round index.
-
-        Args:
-            messages (List[BaseMessage]): The messages.
-
-        Returns:
-            List[List[BaseMessage]]: The messages split by round.
-        """
-        messages_by_round: List[List[BaseMessage]] = []
-        last_round_index = 0
-        for message in messages:
-            if not message.round_index:
-                # Round index must bigger than 0
-                raise ValueError("Message round_index is not set")
-            if message.round_index > last_round_index:
-                last_round_index = message.round_index
-                messages_by_round.append([])
-            messages_by_round[-1].append(message)
-        return messages_by_round
+        return _merge_multi_round_messages(messages_by_round)


 class BufferedConversationMapperOperator(ConversationMapperOperator):
-    """The buffered conversation mapper operator.
+    """
+    The buffered conversation mapper operator which can be configured to keep
+    a certain number of starting and/or ending rounds of a conversation.

-    This Operator must be used after the PreChatHistoryLoadOperator,
-    and it will map the messages in the storage conversation.
+    Args:
+        keep_start_rounds (Optional[int]): Number of initial rounds to keep.
+        keep_end_rounds (Optional[int]): Number of final rounds to keep.

+    Examples:
+        # Keeping the first 2 and the last 1 rounds of a conversation
+        import asyncio
+        from dbgpt.core.interface.message import AIMessage, HumanMessage
+        from dbgpt.core.operator import BufferedConversationMapperOperator

-    Transform no history messages
-
-    .. code-block:: python
-
-        from dbgpt.core import ModelMessage
-        from dbgpt.core.operator import BufferedConversationMapperOperator
-
-        # No history
-        messages = [ModelMessage(role="human", content="Hello", round_index=1)]
-        operator = BufferedConversationMapperOperator(last_k_round=1)
-        assert operator.map_messages(messages) == [
-            ModelMessage(role="human", content="Hello", round_index=1)
-        ]
-
-    Transform with history messages
-
-    .. code-block:: python
-
-        # With history
-        messages = [
-            ModelMessage(role="human", content="Hi", round_index=1),
-            ModelMessage(role="ai", content="Hello!", round_index=1),
-            ModelMessage(role="system", content="Error 404", round_index=2),
-            ModelMessage(role="human", content="What's the error?", round_index=2),
-            ModelMessage(role="ai", content="Just a joke.", round_index=2),
-            ModelMessage(role="human", content="Funny!", round_index=3),
-        ]
-        operator = BufferedConversationMapperOperator(last_k_round=1)
-        # Just keep the last one round, so the first round messages will be removed
-        # Note: The round index 3 is not a complete round
-        assert operator.map_messages(messages) == [
-            ModelMessage(role="system", content="Error 404", round_index=2),
-            ModelMessage(role="human", content="What's the error?", round_index=2),
-            ModelMessage(role="ai", content="Just a joke.", round_index=2),
-            ModelMessage(role="human", content="Funny!", round_index=3),
-        ]
+        operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1)
+        messages = [
+            # Assume each HumanMessage and AIMessage belongs to separate rounds
+            HumanMessage(content="Hi", round_index=1),
+            AIMessage(content="Hello!", round_index=1),
+            HumanMessage(content="How are you?", round_index=2),
+            AIMessage(content="I'm good, thanks!", round_index=2),
+            HumanMessage(content="What's new today?", round_index=3),
+            AIMessage(content="Lots of things!", round_index=3),
+        ]
+        # This will keep rounds 1, 2, and 3
+        assert asyncio.run(operator.map_messages(messages)) == [
+            HumanMessage(content="Hi", round_index=1),
+            AIMessage(content="Hello!", round_index=1),
+            HumanMessage(content="How are you?", round_index=2),
+            AIMessage(content="I'm good, thanks!", round_index=2),
+            HumanMessage(content="What's new today?", round_index=3),
+            AIMessage(content="Lots of things!", round_index=3),
+        ]
     """

     def __init__(
         self,
-        last_k_round: Optional[int] = 2,
+        keep_start_rounds: Optional[int] = None,
+        keep_end_rounds: Optional[int] = None,
         message_mapper: _MultiRoundMessageMapper = None,
         **kwargs,
     ):
-        self._last_k_round = last_k_round
+        # Validate the input parameters
+        if keep_start_rounds is not None and keep_start_rounds < 0:
+            raise ValueError("keep_start_rounds must be non-negative")
+        if keep_end_rounds is not None and keep_end_rounds < 0:
+            raise ValueError("keep_end_rounds must be non-negative")
+
+        self._keep_start_rounds = keep_start_rounds
+        self._keep_end_rounds = keep_end_rounds
         if message_mapper:

             def new_message_mapper(
                 messages_by_round: List[List[BaseMessage]],
             ) -> List[BaseMessage]:
                 # Apply keep k round messages first, then apply the custom message mapper
-                messages_by_round = self._keep_last_round_messages(messages_by_round)
+                messages_by_round = self._filter_round_messages(messages_by_round)
                 return message_mapper(messages_by_round)

         else:
@@ -327,21 +303,189 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
             def new_message_mapper(
                 messages_by_round: List[List[BaseMessage]],
             ) -> List[BaseMessage]:
-                messages_by_round = self._keep_last_round_messages(messages_by_round)
-                return sum(messages_by_round, [])
+                messages_by_round = self._filter_round_messages(messages_by_round)
+                return _merge_multi_round_messages(messages_by_round)

         super().__init__(new_message_mapper, **kwargs)

-    def _keep_last_round_messages(
+    def _filter_round_messages(
         self, messages_by_round: List[List[BaseMessage]]
     ) -> List[List[BaseMessage]]:
-        """Keep the last k round messages.
-
-        Args:
-            messages_by_round (List[List[BaseMessage]]): The messages by round.
-
-        Returns:
-            List[List[BaseMessage]]: The latest round messages.
-        """
-        index = self._last_k_round + 1
-        return messages_by_round[-index:]
+        """Filters the messages to keep only the specified starting and/or ending rounds.
+
+        Examples:
+            >>> from dbgpt.core import AIMessage, HumanMessage
+            >>> from dbgpt.core.operator import BufferedConversationMapperOperator
+            >>> messages = [
+            ...     [
+            ...         HumanMessage(content="Hi", round_index=1),
+            ...         AIMessage(content="Hello!", round_index=1),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="How are you?", round_index=2),
+            ...         AIMessage(content="I'm good, thanks!", round_index=2),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="What's new today?", round_index=3),
+            ...         AIMessage(content="Lots of things!", round_index=3),
+            ...     ],
+            ... ]
+
+            # Test keeping only the first 2 rounds
+            >>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
+            >>> assert operator._filter_round_messages(messages) == [
+            ...     [
+            ...         HumanMessage(content="Hi", round_index=1),
+            ...         AIMessage(content="Hello!", round_index=1),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="How are you?", round_index=2),
+            ...         AIMessage(content="I'm good, thanks!", round_index=2),
+            ...     ],
+            ... ]
+
+            # Test keeping only the last 2 rounds
+            >>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
+            >>> assert operator._filter_round_messages(messages) == [
+            ...     [
+            ...         HumanMessage(content="How are you?", round_index=2),
+            ...         AIMessage(content="I'm good, thanks!", round_index=2),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="What's new today?", round_index=3),
+            ...         AIMessage(content="Lots of things!", round_index=3),
+            ...     ],
+            ... ]
+
+            # Test keeping the first 2 and last 1 rounds
+            >>> operator = BufferedConversationMapperOperator(
+            ...     keep_start_rounds=2, keep_end_rounds=1
+            ... )
+            >>> assert operator._filter_round_messages(messages) == [
+            ...     [
+            ...         HumanMessage(content="Hi", round_index=1),
+            ...         AIMessage(content="Hello!", round_index=1),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="How are you?", round_index=2),
+            ...         AIMessage(content="I'm good, thanks!", round_index=2),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="What's new today?", round_index=3),
+            ...         AIMessage(content="Lots of things!", round_index=3),
+            ...     ],
+            ... ]
+
+            # Test without specifying start or end rounds (keep all rounds)
+            >>> operator = BufferedConversationMapperOperator()
+            >>> assert operator._filter_round_messages(messages) == [
+            ...     [
+            ...         HumanMessage(content="Hi", round_index=1),
+            ...         AIMessage(content="Hello!", round_index=1),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="How are you?", round_index=2),
+            ...         AIMessage(content="I'm good, thanks!", round_index=2),
+            ...     ],
+            ...     [
+            ...         HumanMessage(content="What's new today?", round_index=3),
+            ...         AIMessage(content="Lots of things!", round_index=3),
+            ...     ],
+            ... ]
+
+        Args:
+            messages_by_round (List[List[BaseMessage]]): The messages grouped by round.
+
+        Returns:
+            List[List[BaseMessage]]: Filtered list of messages.
+        """
+        total_rounds = len(messages_by_round)
+        if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
+            if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
+                # Avoid overlapping when the sum of start and end rounds exceeds total rounds
+                return messages_by_round
+            return (
+                messages_by_round[: self._keep_start_rounds]
+                + messages_by_round[-self._keep_end_rounds :]
+            )
+        elif self._keep_start_rounds is not None:
+            return messages_by_round[: self._keep_start_rounds]
+        elif self._keep_end_rounds is not None:
+            return messages_by_round[-self._keep_end_rounds :]
+        else:
+            return messages_by_round
+
+
+EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
+
+
+class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
+    """The token buffered conversation mapper operator.
+
+    If the token count of the messages is greater than the max token limit, we will evict the messages by round.
+
+    Args:
+        model (str): The model name.
+        llm_client (LLMClient): The LLM client.
+        max_token_limit (int): The max token limit.
+        eviction_policy (EvictionPolicyType): The eviction policy.
+        message_mapper (_MultiRoundMessageMapper): The message mapper, it applies after all messages are handled.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        llm_client: LLMClient,
+        max_token_limit: int = 2000,
+        eviction_policy: EvictionPolicyType = None,
+        message_mapper: _MultiRoundMessageMapper = None,
+        **kwargs,
+    ):
+        if max_token_limit < 0:
+            raise ValueError("Max token limit can't be negative")
+        self._model = model
+        self._llm_client = llm_client
+        self._max_token_limit = max_token_limit
+        self._eviction_policy = eviction_policy
+        self._message_mapper = message_mapper
+        super().__init__(**kwargs)
+
+    async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
+        eviction_policy = self._eviction_policy or self.eviction_policy
+        messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
+        messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round))
+        # First time, we count the token of the messages
+        current_tokens = await self._llm_client.count_token(self._model, messages_str)
+
+        while current_tokens > self._max_token_limit:
+            # Evict the messages by round until the token count is not greater than the max token limit
+            # TODO: We should find a high performance way to do this
+            messages_by_round = eviction_policy(messages_by_round)
+            messages_str = _messages_to_str(
+                _merge_multi_round_messages(messages_by_round)
+            )
+            current_tokens = await self._llm_client.count_token(
+                self._model, messages_str
+            )
+        message_mapper = self._message_mapper or self.map_multi_round_messages
+        return message_mapper(messages_by_round)
+
+    def eviction_policy(
+        self, messages_by_round: List[List[BaseMessage]]
+    ) -> List[List[BaseMessage]]:
+        """Evict the messages by round, default is FIFO.
+
+        Args:
+            messages_by_round (List[List[BaseMessage]]): The messages by round.
+
+        Returns:
+            List[List[BaseMessage]]: The evicted messages by round.
+        """
+        messages_by_round.pop(0)
+        return messages_by_round
+
+
+def _merge_multi_round_messages(messages: List[List[BaseMessage]]) -> List[BaseMessage]:
+    # e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
+    return sum(messages, [])
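The eviction loop re-counts tokens after each evicted round until the transcript fits the budget. A self-contained sketch where len() stands in for the real (async) count_token call:

```python
def evict_until_fits(messages_by_round, max_tokens):
    def count_tokens(rounds):
        # Stand-in for llm_client.count_token: one "token" per character
        return sum(len(m) for r in rounds for m in r)

    while count_tokens(messages_by_round) > max_tokens and messages_by_round:
        messages_by_round.pop(0)  # default eviction policy: oldest round first (FIFO)
    return messages_by_round


rounds = [["aaaa", "bbbb"], ["cc"], ["dd"]]
assert evict_until_fits(rounds, max_tokens=5) == [["cc"], ["dd"]]
```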
@@ -216,18 +216,27 @@ class HistoryPromptBuilderOperator(
     BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
 ):
     def __init__(
-        self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
+        self,
+        prompt: ChatPromptTemplate,
+        history_key: Optional[str] = None,
+        check_storage: bool = True,
+        str_history: bool = False,
+        **kwargs,
     ):
         self._prompt = prompt
         self._history_key = history_key
-
+        self._str_history = str_history
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type
     async def merge_history(
         self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
     ) -> List[ModelMessage]:
-        prompt_dict[self._history_key] = history
+        if self._str_history:
+            prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
+        else:
+            prompt_dict[self._history_key] = history
         return await self.format_prompt(self._prompt, prompt_dict)


@@ -239,9 +248,16 @@ class HistoryDynamicPromptBuilderOperator(
     The prompt template is dynamic, and it is created by the parent operator.
     """

-    def __init__(self, history_key: Optional[str] = None, **kwargs):
+    def __init__(
+        self,
+        history_key: Optional[str] = None,
+        check_storage: bool = True,
+        str_history: bool = False,
+        **kwargs,
+    ):
         self._history_key = history_key
-
+        self._str_history = str_history
+        BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
         JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)

     @rearrange_args_by_type
@@ -251,5 +267,8 @@ class HistoryDynamicPromptBuilderOperator(
         history: List[BaseMessage],
         prompt_dict: Dict[str, Any],
     ) -> List[ModelMessage]:
-        prompt_dict[self._history_key] = history
+        if self._str_history:
+            prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
+        else:
+            prompt_dict[self._history_key] = history
         return await self.format_prompt(prompt, prompt_dict)
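With str_history=True, the history is injected into the prompt variables as one transcript string instead of a message list. A sketch of that switch over (role, content) tuples:

```python
def merge_history(prompt_dict, history, history_key="chat_history", str_history=False):
    if str_history:
        # Collapse the history into a single transcript string
        prompt_dict[history_key] = "\n".join(f"{role}: {c}" for role, c in history)
    else:
        prompt_dict[history_key] = history
    return prompt_dict


history = [("Human", "Hi"), ("AI", "Hello!")]
assert merge_history({}, history, str_history=True) == {
    "chat_history": "Human: Hi\nAI: Hello!"
}
```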
@@ -49,13 +49,25 @@ class BasePromptTemplate(BaseModel):
     """The prompt template."""

     template_format: Optional[str] = "f-string"
     """The format of the prompt template. Options are: 'f-string', 'jinja2'."""

+    response_format: Optional[str] = None
+
+    response_key: Optional[str] = "response"
+
+    template_is_strict: Optional[bool] = True
+    """strict template will check template args"""
+
     def format(self, **kwargs: Any) -> str:
         """Format the prompt with the inputs."""
-        if self.template:
-            return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
-                self.template, **kwargs
-            )
+        if self.response_format:
+            kwargs[self.response_key] = json.dumps(
+                self.response_format, ensure_ascii=False, indent=4
+            )
+        return _DEFAULT_FORMATTER_MAPPING[self.template_format](
+            self.template_is_strict
+        )(self.template, **kwargs)

     @classmethod
     def from_template(
@@ -75,10 +87,6 @@ class PromptTemplate(BasePromptTemplate):
     template_scene: Optional[str]
     template_define: Optional[str]
     """this template define"""
-    """strict template will check template args"""
-    template_is_strict: bool = True
-    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
-    response_format: Optional[str]
     """default use stream out"""
     stream_out: bool = True
     """"""
@@ -103,17 +111,6 @@ class PromptTemplate(BasePromptTemplate):
         """Return the prompt type key."""
         return "prompt"

-    def format(self, **kwargs: Any) -> str:
-        """Format the prompt with the inputs."""
-        if self.template:
-            if self.response_format:
-                kwargs["response"] = json.dumps(
-                    self.response_format, ensure_ascii=False, indent=4
-                )
-            return _DEFAULT_FORMATTER_MAPPING[self.template_format](
-                self.template_is_strict
-            )(self.template, **kwargs)
-

 class BaseChatPromptTemplate(BaseModel, ABC):
     prompt: BasePromptTemplate
@@ -129,10 +126,22 @@ class BaseChatPromptTemplate(BaseModel, ABC):

     @classmethod
     def from_template(
-        cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
+        cls,
+        template: str,
+        template_format: Optional[str] = "f-string",
+        response_format: Optional[str] = None,
+        response_key: Optional[str] = "response",
+        template_is_strict: bool = True,
+        **kwargs: Any,
     ) -> BaseChatPromptTemplate:
         """Create a prompt template from a template string."""
-        prompt = BasePromptTemplate.from_template(template, template_format)
+        prompt = BasePromptTemplate.from_template(
+            template,
+            template_format,
+            response_format=response_format,
+            response_key=response_key,
+            template_is_strict=template_is_strict,
+        )
         return cls(prompt=prompt, **kwargs)
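BasePromptTemplate.format now JSON-dumps response_format into the kwargs under response_key before rendering the template. The same flow with plain str.format standing in for the formatter mapping (the template and schema below are illustrative):

```python
import json

template = "Answer in this JSON shape:\n{response}\nQuestion: {question}"
response_format = {"answer": "...", "confidence": 0.0}

kwargs = {"question": "What is DB-GPT?"}
# Mirrors format(): inject the serialized response schema under response_key
kwargs["response"] = json.dumps(response_format, ensure_ascii=False, indent=4)
prompt = template.format(**kwargs)
assert '"confidence"' in prompt and "What is DB-GPT?" in prompt
```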
@@ -284,6 +284,15 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
             resource_id (ResourceIdentifier): The resource identifier of the data
         """

+    def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
+        """Delete the data from the storage.
+
+        Args:
+            resource_id (List[ResourceIdentifier]): The resource identifiers of the data
+        """
+        for r in resource_id:
+            self.delete(r)
+
     @abstractmethod
     def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
         """Query data from the storage.
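The default delete_list is a plain loop over delete(), so every backend gets batch deletion for free and can override it with something genuinely batched. A minimal in-memory sketch:

```python
class InMemoryStorage:
    """Stand-in storage keyed by plain strings instead of ResourceIdentifier."""

    def __init__(self):
        self.data = {"a": 1, "b": 2, "c": 3}

    def delete(self, resource_id):
        self.data.pop(resource_id, None)

    def delete_list(self, resource_ids):
        # Same shape as the default StorageInterface.delete_list
        for r in resource_ids:
            self.delete(r)


storage = InMemoryStorage()
storage.delete_list(["a", "c"])
assert storage.data == {"b": 2}
```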
@@ -138,14 +138,14 @@ def test_clear_messages(basic_conversation, human_message):
 def test_get_latest_user_message(basic_conversation, human_message):
     basic_conversation.add_user_message(human_message.content)
     latest_message = basic_conversation.get_latest_user_message()
-    assert latest_message == human_message
+    assert latest_message.content == human_message.content


 def test_get_system_messages(basic_conversation, system_message):
     basic_conversation.add_system_message(system_message.content)
     system_messages = basic_conversation.get_system_messages()
     assert len(system_messages) == 1
-    assert system_messages[0] == system_message
+    assert system_messages[0].content == system_message.content


 def test_from_conversation(basic_conversation):
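The tests switch from whole-object equality to content comparison: messages created inside the conversation (now stamped with extra additional_kwargs) are distinct objects from the fixtures. A small illustration of why object-level comparison is fragile, using a stand-in class without __eq__ (not the dbgpt message types):

```python
class Msg:
    def __init__(self, content):
        self.content = content


a, b = Msg("hello"), Msg("hello")
assert a != b                   # distinct instances compare by identity
assert a.content == b.content   # content comparison still holds
```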
@@ -324,6 +324,35 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
     assert isinstance(new_conversation.messages[1], AIMessage)


+def test_delete(storage_conversation, in_memory_storage):
+    # Set storage
+    storage_conversation.conv_storage = in_memory_storage
+    storage_conversation.message_storage = in_memory_storage
+
+    # Add messages and save to storage
+    storage_conversation.start_new_round()
+    storage_conversation.add_user_message("User message")
+    storage_conversation.add_ai_message("AI response")
+    storage_conversation.end_current_round()
+
+    # Create a new StorageConversation instance to load the data
+    new_conversation = StorageConversation(
+        "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
+    )
+
+    # Delete the conversation
+    new_conversation.delete()
+
+    # Check if the conversation is deleted
+    assert new_conversation.conv_uid == storage_conversation.conv_uid
+    assert len(new_conversation.messages) == 0
+
+    no_messages_conv = StorageConversation(
+        "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
+    )
+    assert len(no_messages_conv.messages) == 0
+
+
 def test_parse_model_messages_no_history_messages():
     messages = [
         ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),
@@ -14,6 +14,7 @@ from dbgpt.core.interface.operator.message_operator import (
     BufferedConversationMapperOperator,
     ConversationMapperOperator,
     PreChatHistoryLoadOperator,
+    TokenBufferedConversationMapperOperator,
 )
 from dbgpt.core.interface.operator.prompt_operator import (
     DynamicPromptBuilderOperator,
@@ -30,6 +31,7 @@ __ALL__ = [
     "BaseStreamingLLMOperator",
     "BaseConversationOperator",
     "BufferedConversationMapperOperator",
+    "TokenBufferedConversationMapperOperator",
     "ConversationMapperOperator",
     "PreChatHistoryLoadOperator",
     "PromptBuilderOperator",