feat(core): APP use new SDK component (#1050)

Fangyin Cheng
2024-01-10 10:39:04 +08:00
committed by GitHub
parent e11b72c724
commit fa8b5b190c
242 changed files with 2768 additions and 2163 deletions


@@ -20,8 +20,9 @@ class PromptTemplateRegistry:
self,
prompt_template,
language: str = "en",
is_default=False,
is_default: bool = False,
model_names: List[str] = None,
scene_name: str = None,
) -> None:
"""Register prompt template with scene name, language
registry dict format:
@@ -37,7 +38,8 @@ class PromptTemplateRegistry:
}
}
"""
scene_name = prompt_template.template_scene
if not scene_name:
scene_name = prompt_template.template_scene
if not scene_name:
raise ValueError("Prompt template scene name cannot be empty")
if not model_names:
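For orientation, a minimal sketch of the stricter `register` call. The import path, the field values, and the registry instance are assumptions; only the signature and the empty-scene ValueError come from this hunk:

from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateRegistry  # assumed path

registry = PromptTemplateRegistry()
template = PromptTemplate(
    template_scene="chat_normal",  # must be non-empty, otherwise ValueError
    input_variables=["input"],
    template="{input}",
)
# is_default is now explicitly typed as bool; model_names and scene_name stay optional
registry.register(template, language="en", is_default=True)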


@@ -45,6 +45,18 @@ class BaseMessage(BaseModel, ABC):
"round_index": self.round_index,
}
@staticmethod
def messages_to_string(messages: List["BaseMessage"]) -> str:
"""Convert messages to str
Args:
messages (List[BaseMessage]): The messages
Returns:
str: The str messages
"""
return _messages_to_str(messages)
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -251,6 +263,41 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
return [_message_to_dict(m) for m in messages]
def _messages_to_str(
messages: List[BaseMessage],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
) -> str:
"""Convert messages to str
Args:
messages (List[BaseMessage]): The messages
human_prefix (str): The human prefix
ai_prefix (str): The ai prefix
system_prefix (str): The system prefix
Returns:
str: The str messages
"""
str_messages = []
for message in messages:
role = None
if isinstance(message, HumanMessage):
role = human_prefix
elif isinstance(message, AIMessage):
role = ai_prefix
elif isinstance(message, SystemMessage):
role = system_prefix
elif isinstance(message, ViewMessage):
pass
else:
raise ValueError(f"Got unsupported message type: {message}")
if role:
str_messages.append(f"{role}: {message.content}")
return "\n".join(str_messages)
def _message_from_dict(message: Dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
@@ -382,6 +429,9 @@ class OnceConversation:
self._message_index += 1
message.index = index
message.round_index = self.chat_order
message.additional_kwargs["param_type"] = self.param_type
message.additional_kwargs["param_value"] = self.param_value
message.additional_kwargs["model_name"] = self.model_name
self.messages.append(message)
def start_new_round(self) -> None:
@@ -504,9 +554,12 @@ class OnceConversation:
self.messages = conversation.messages
self.start_date = conversation.start_date
self.chat_order = conversation.chat_order
self.model_name = conversation.model_name
self.param_type = conversation.param_type
self.param_value = conversation.param_value
if not self.model_name and conversation.model_name:
self.model_name = conversation.model_name
if not self.param_type and conversation.param_type:
self.param_type = conversation.param_type
if not self.param_value and conversation.param_value:
self.param_value = conversation.param_value
self.cost = conversation.cost
self.tokens = conversation.tokens
self.user_name = conversation.user_name
@@ -801,6 +854,7 @@ class StorageConversation(OnceConversation, StorageItem):
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
load_message: bool = True,
**kwargs,
):
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
@@ -811,6 +865,8 @@ class StorageConversation(OnceConversation, StorageItem):
self._has_stored_message_index = (
len(kwargs["messages"]) - 1 if "messages" in kwargs else -1
)
# Whether to load the messages from the storage
self._load_message = load_message
self.save_message_independent = save_message_independent
self._id = ConversationIdentifier(conv_uid)
if conv_storage is None:
@@ -853,7 +909,9 @@ class StorageConversation(OnceConversation, StorageItem):
]
messages_to_save = message_list[self._has_stored_message_index + 1 :]
self._has_stored_message_index = len(message_list) - 1
self.message_storage.save_list(messages_to_save)
if self.save_message_independent:
# Save messages independently
self.message_storage.save_list(messages_to_save)
# Save conversation
self.conv_storage.save_or_update(self)
@@ -876,23 +934,71 @@ class StorageConversation(OnceConversation, StorageItem):
return
message_ids = conversation._message_ids or []
# Load messages
message_list = message_storage.load_list(
[
MessageIdentifier.from_str_identifier(message_id)
for message_id in message_ids
],
MessageStorageItem,
)
messages = [message.to_message() for message in message_list]
conversation.messages = messages
if self._load_message:
# Load messages
message_list = message_storage.load_list(
[
MessageIdentifier.from_str_identifier(message_id)
for message_id in message_ids
],
MessageStorageItem,
)
messages = [message.to_message() for message in message_list]
else:
messages = []
real_messages = messages or conversation.messages
conversation.messages = real_messages
# This index is used to save messages to the storage (those not yet saved)
# New messages are appended, so the index is len(messages)
conversation._message_index = len(messages)
conversation._message_index = len(real_messages)
conversation.chat_order = (
max(m.round_index for m in real_messages) if real_messages else 0
)
self._append_additional_kwargs(conversation, real_messages)
self._message_ids = message_ids
self._has_stored_message_index = len(messages) - 1
self._has_stored_message_index = len(real_messages) - 1
self.save_message_independent = conversation.save_message_independent
self.from_conversation(conversation)
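A sketch of the new `load_message` flag; the storage instance here is a placeholder:

# Skip loading messages from storage when only conversation metadata is needed
conv = StorageConversation(
    "conv_uid_1",
    conv_storage=in_memory_storage,     # placeholder StorageInterface
    message_storage=in_memory_storage,  # placeholder StorageInterface
    load_message=False,                 # falls back to conversation.messages, else empty
)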
def _append_additional_kwargs(
self, conversation: StorageConversation, messages: List[BaseMessage]
) -> None:
"""Parse the additional kwargs and append to the conversation
Args:
conversation (StorageConversation): The conversation
messages (List[BaseMessage]): The messages
"""
param_type = None
param_value = None
for message in messages[::-1]:
if message.additional_kwargs:
param_type = message.additional_kwargs.get("param_type")
param_value = message.additional_kwargs.get("param_value")
break
if not conversation.param_type:
conversation.param_type = param_type
if not conversation.param_value:
conversation.param_value = param_value
def delete(self) -> None:
"""Delete all the messages and conversation from the storage"""
# Delete messages first
message_list = self._get_message_items()
message_ids = [message.identifier for message in message_list]
self.message_storage.delete_list(message_ids)
# Delete conversation
self.conv_storage.delete(self.identifier)
# Overwrite the current conversation with an empty conversation
self.from_conversation(
StorageConversation(
self.conv_uid,
save_message_independent=self.save_message_independent,
conv_storage=self.conv_storage,
message_storage=self.message_storage,
)
)
def _conversation_to_dict(once: OnceConversation) -> Dict:
start_str: str = ""
@@ -937,3 +1043,61 @@ def _conversation_from_dict(once: dict) -> OnceConversation:
print(once.get("messages"))
conversation.messages = _messages_from_dict(once.get("messages", []))
return conversation
def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessage]]:
"""Split the messages by round index.
Args:
messages (List[BaseMessage]): The messages.
Returns:
List[List[BaseMessage]]: The messages split by round.
"""
messages_by_round: List[List[BaseMessage]] = []
last_round_index = 0
for message in messages:
if not message.round_index:
# Round index must be bigger than 0
raise ValueError("Message round_index is not set")
if message.round_index > last_round_index:
last_round_index = message.round_index
messages_by_round.append([])
messages_by_round[-1].append(message)
return messages_by_round
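A behavior sketch of the extracted helper, using the message classes from this module:

messages = [
    HumanMessage(content="Hi", round_index=1),
    AIMessage(content="Hello!", round_index=1),
    HumanMessage(content="Bye", round_index=2),
]
rounds = _split_messages_by_round(messages)
assert rounds == [[messages[0], messages[1]], [messages[2]]]
# A message whose round_index is unset (0) raises ValueError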
def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
"""Append the view message to the messages
Just for show in DB-GPT-Web.
If already have view message, do nothing.
Args:
messages (List[BaseMessage]): The messages
Returns:
List[BaseMessage]: The messages with view message
"""
messages_by_round = _split_messages_by_round(messages)
for current_round in messages_by_round:
ai_message = None
view_message = None
for message in current_round:
if message.type == "ai":
ai_message = message
elif message.type == "view":
view_message = message
if view_message:
# Already has a view message, do nothing
continue
if ai_message:
view_message = ViewMessage(
content=ai_message.content,
index=ai_message.index,
round_index=ai_message.round_index,
additional_kwargs=ai_message.additional_kwargs.copy()
if ai_message.additional_kwargs
else {},
)
current_round.append(view_message)
return sum(messages_by_round, [])
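A behavior sketch: a round with an AI message but no view message gains a ViewMessage mirroring the AI content:

messages = [
    HumanMessage(content="Hi", round_index=1),
    AIMessage(content="Hello!", round_index=1),
]
result = _append_view_messages(messages)
assert result[-1].type == "view"
assert result[-1].content == "Hello!"
# Calling it again is a no-op, because the round already has a view message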


@@ -45,7 +45,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
self,
prompt_template: ChatPromptTemplate,
history_key: str = "chat_history",
last_k_round: int = 2,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
storage: Optional[StorageInterface[StorageConversation, Any]] = None,
message_storage: Optional[StorageInterface[MessageStorageItem, Any]] = None,
**kwargs,
@@ -53,7 +54,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
super().__init__(**kwargs)
self._prompt_template = prompt_template
self._history_key = history_key
self._last_k_round = last_k_round
self._keep_start_rounds = keep_start_rounds
self._keep_end_rounds = keep_end_rounds
self._storage = storage
self._message_storage = message_storage
self._sub_compose_dag = self._build_composer_dag()
@@ -74,7 +76,8 @@ class ChatHistoryPromptComposerOperator(MapOperator[ChatComposerInput, ModelRequ
)
# History transform task; keep the configured start/end rounds of messages
history_transform_task = BufferedConversationMapperOperator(
last_k_round=self._last_k_round
keep_start_rounds=self._keep_start_rounds,
keep_end_rounds=self._keep_end_rounds,
)
history_prompt_build_task = HistoryPromptBuilderOperator(
prompt=self._prompt_template, history_key=self._history_key
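Callers migrate from `last_k_round` to the new pair; a construction sketch, where the template object is a placeholder:

composer = ChatHistoryPromptComposerOperator(
    prompt_template=my_chat_prompt_template,  # placeholder ChatPromptTemplate
    keep_start_rounds=1,  # e.g. keep the opening round with initial instructions
    keep_end_rounds=2,    # plus the two most recent rounds
)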


@@ -1,8 +1,9 @@
import uuid
from abc import ABC
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
from dbgpt.core import (
LLMClient,
MessageStorageItem,
ModelMessage,
ModelMessageRoleType,
@@ -11,7 +12,12 @@ from dbgpt.core import (
StorageInterface,
)
from dbgpt.core.awel import BaseOperator, MapOperator
from dbgpt.core.interface.message import BaseMessage, _MultiRoundMessageMapper
from dbgpt.core.interface.message import (
BaseMessage,
_messages_to_str,
_MultiRoundMessageMapper,
_split_messages_by_round,
)
class BaseConversationOperator(BaseOperator, ABC):
@@ -31,7 +37,6 @@ class BaseConversationOperator(BaseOperator, ABC):
**kwargs,
):
self._check_storage = check_storage
super().__init__(**kwargs)
self._storage = storage
self._message_storage = message_storage
@@ -167,12 +172,10 @@ class ConversationMapperOperator(
self._message_mapper = message_mapper
async def map(self, input_value: List[BaseMessage]) -> List[BaseMessage]:
return self.map_messages(input_value)
return await self.map_messages(input_value)
def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
messages_by_round: List[List[BaseMessage]] = self._split_messages_by_round(
messages
)
async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
@@ -233,93 +236,66 @@ class ConversationMapperOperator(
Args:
"""
# Just merge and return
# e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
return sum(messages_by_round, [])
def _split_messages_by_round(
self, messages: List[BaseMessage]
) -> List[List[BaseMessage]]:
"""Split the messages by round index.
Args:
messages (List[BaseMessage]): The messages.
Returns:
List[List[BaseMessage]]: The messages split by round.
"""
messages_by_round: List[List[BaseMessage]] = []
last_round_index = 0
for message in messages:
if not message.round_index:
# Round index must be bigger than 0
raise ValueError("Message round_index is not set")
if message.round_index > last_round_index:
last_round_index = message.round_index
messages_by_round.append([])
messages_by_round[-1].append(message)
return messages_by_round
return _merge_multi_round_messages(messages_by_round)
class BufferedConversationMapperOperator(ConversationMapperOperator):
"""The buffered conversation mapper operator.
"""
The buffered conversation mapper operator which can be configured to keep
a certain number of starting and/or ending rounds of a conversation.
This Operator must be used after the PreChatHistoryLoadOperator,
and it will map the messages in the storage conversation.
Args:
keep_start_rounds (Optional[int]): Number of initial rounds to keep.
keep_end_rounds (Optional[int]): Number of final rounds to keep.
Examples:
# Keeping the first 2 and the last 1 rounds of a conversation
import asyncio
from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
Transform no history messages
.. code-block:: python
from dbgpt.core import ModelMessage
from dbgpt.core.operator import BufferedConversationMapperOperator
# No history
messages = [ModelMessage(role="human", content="Hello", round_index=1)]
operator = BufferedConversationMapperOperator(last_k_round=1)
assert operator.map_messages(messages) == [
ModelMessage(role="human", content="Hello", round_index=1)
]
Transform with history messages
.. code-block:: python
# With history
messages = [
ModelMessage(role="human", content="Hi", round_index=1),
ModelMessage(role="ai", content="Hello!", round_index=1),
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(last_k_round=1)
# Just keep the last round, so the first round messages will be removed
# Note: The round index 3 is not a complete round
assert operator.map_messages(messages) == [
ModelMessage(role="system", content="Error 404", round_index=2),
ModelMessage(role="human", content="What's the error?", round_index=2),
ModelMessage(role="ai", content="Just a joke.", round_index=2),
ModelMessage(role="human", content="Funny!", round_index=3),
]
operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=1)
messages = [
# Assume each HumanMessage and AIMessage belongs to separate rounds
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
# This will keep rounds 1, 2, and 3
assert asyncio.run(operator.map_messages(messages)) == [
HumanMessage(content="Hi", round_index=1),
AIMessage(content="Hello!", round_index=1),
HumanMessage(content="How are you?", round_index=2),
AIMessage(content="I'm good, thanks!", round_index=2),
HumanMessage(content="What's new today?", round_index=3),
AIMessage(content="Lots of things!", round_index=3),
]
"""
def __init__(
self,
last_k_round: Optional[int] = 2,
keep_start_rounds: Optional[int] = None,
keep_end_rounds: Optional[int] = None,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
self._last_k_round = last_k_round
# Validate the input parameters
if keep_start_rounds is not None and keep_start_rounds < 0:
raise ValueError("keep_start_rounds must be non-negative")
if keep_end_rounds is not None and keep_end_rounds < 0:
raise ValueError("keep_end_rounds must be non-negative")
self._keep_start_rounds = keep_start_rounds
self._keep_end_rounds = keep_end_rounds
if message_mapper:
def new_message_mapper(
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
# Filter the round messages first, then apply the custom message mapper
messages_by_round = self._keep_last_round_messages(messages_by_round)
messages_by_round = self._filter_round_messages(messages_by_round)
return message_mapper(messages_by_round)
else:
@@ -327,21 +303,189 @@ class BufferedConversationMapperOperator(ConversationMapperOperator):
def new_message_mapper(
messages_by_round: List[List[BaseMessage]],
) -> List[BaseMessage]:
messages_by_round = self._keep_last_round_messages(messages_by_round)
return sum(messages_by_round, [])
messages_by_round = self._filter_round_messages(messages_by_round)
return _merge_multi_round_messages(messages_by_round)
super().__init__(new_message_mapper, **kwargs)
def _keep_last_round_messages(
def _filter_round_messages(
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
"""Keep the last k round messages.
"""Filters the messages to keep only the specified starting and/or ending rounds.
Examples:
>>> from dbgpt.core import AIMessage, HumanMessage
>>> from dbgpt.core.operator import BufferedConversationMapperOperator
>>> messages = [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
# Test keeping only the first 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_start_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... ]
# Test keeping only the last 2 rounds
>>> operator = BufferedConversationMapperOperator(keep_end_rounds=2)
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
# Test keeping the first 2 and last 1 rounds
>>> operator = BufferedConversationMapperOperator(
... keep_start_rounds=2, keep_end_rounds=1
... )
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
# Test without specifying start or end rounds (keep all rounds)
>>> operator = BufferedConversationMapperOperator()
>>> assert operator._filter_round_messages(messages) == [
... [
... HumanMessage(content="Hi", round_index=1),
... AIMessage(content="Hello!", round_index=1),
... ],
... [
... HumanMessage(content="How are you?", round_index=2),
... AIMessage(content="I'm good, thanks!", round_index=2),
... ],
... [
... HumanMessage(content="What's new today?", round_index=3),
... AIMessage(content="Lots of things!", round_index=3),
... ],
... ]
Args:
messages_by_round (List[List[BaseMessage]]): The messages grouped by round.
Returns:
List[List[BaseMessage]]: Filtered list of messages.
"""
total_rounds = len(messages_by_round)
if self._keep_start_rounds is not None and self._keep_end_rounds is not None:
if self._keep_start_rounds + self._keep_end_rounds > total_rounds:
# Avoid overlapping when the sum of start and end rounds exceeds total rounds
return messages_by_round
return (
messages_by_round[: self._keep_start_rounds]
+ messages_by_round[-self._keep_end_rounds :]
)
elif self._keep_start_rounds is not None:
return messages_by_round[: self._keep_start_rounds]
elif self._keep_end_rounds is not None:
return messages_by_round[-self._keep_end_rounds :]
else:
return messages_by_round
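One case the doctests above leave implicit: when the two windows would overlap, the guard keeps every round rather than duplicating any. A sketch reusing the 3-round `messages` list from the doctest:

operator = BufferedConversationMapperOperator(keep_start_rounds=2, keep_end_rounds=2)
assert operator._filter_round_messages(messages) == messages  # 2 + 2 > 3, keep all rounds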
EvictionPolicyType = Callable[[List[List[BaseMessage]]], List[List[BaseMessage]]]
class TokenBufferedConversationMapperOperator(ConversationMapperOperator):
"""The token buffered conversation mapper operator.
If the token count of the messages is greater than the max token limit, we will evict the messages by round.
Args:
model (str): The model name.
llm_client (LLMClient): The LLM client.
max_token_limit (int): The max token limit.
eviction_policy (EvictionPolicyType): The eviction policy.
message_mapper (_MultiRoundMessageMapper): The message mapper, applied after all messages are handled.
"""
def __init__(
self,
model: str,
llm_client: LLMClient,
max_token_limit: int = 2000,
eviction_policy: EvictionPolicyType = None,
message_mapper: _MultiRoundMessageMapper = None,
**kwargs,
):
if max_token_limit < 0:
raise ValueError("Max token limit can't be negative")
self._model = model
self._llm_client = llm_client
self._max_token_limit = max_token_limit
self._eviction_policy = eviction_policy
self._message_mapper = message_mapper
super().__init__(**kwargs)
async def map_messages(self, messages: List[BaseMessage]) -> List[BaseMessage]:
eviction_policy = self._eviction_policy or self.eviction_policy
messages_by_round: List[List[BaseMessage]] = _split_messages_by_round(messages)
messages_str = _messages_to_str(_merge_multi_round_messages(messages_by_round))
# First, count the tokens of the messages
current_tokens = await self._llm_client.count_token(self._model, messages_str)
while current_tokens > self._max_token_limit:
# Evict messages round by round until the token count is not greater than the max token limit
# TODO: We should find a high-performance way to do this
messages_by_round = eviction_policy(messages_by_round)
messages_str = _messages_to_str(
_merge_multi_round_messages(messages_by_round)
)
current_tokens = await self._llm_client.count_token(
self._model, messages_str
)
message_mapper = self._message_mapper or self.map_multi_round_messages
return message_mapper(messages_by_round)
def eviction_policy(
self, messages_by_round: List[List[BaseMessage]]
) -> List[List[BaseMessage]]:
"""Evict the messages by round, default is FIFO.
Args:
messages_by_round (List[List[BaseMessage]]): The messages by round.
Returns:
List[List[BaseMessage]]: The latest round messages.
List[List[BaseMessage]]: The evicted messages by round.
"""
index = self._last_k_round + 1
return messages_by_round[-index:]
messages_by_round.pop(0)
return messages_by_round
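An end-to-end sketch of the token-bounded mapper. The fake client below only mimics the `count_token(model, text)` call shape used above and is not a real `LLMClient`; the model name is hypothetical. Standalone operator construction follows the pattern the docstrings above already use:

import asyncio

from dbgpt.core.interface.message import AIMessage, HumanMessage
from dbgpt.core.operator import TokenBufferedConversationMapperOperator

class _FakeLLMClient:
    async def count_token(self, model: str, prompt: str) -> int:
        return len(prompt) // 4  # crude stand-in: ~1 token per 4 characters

async def main():
    operator = TokenBufferedConversationMapperOperator(
        model="vicuna-13b-v1.5",      # hypothetical model name
        llm_client=_FakeLLMClient(),  # duck-typed for this sketch
        max_token_limit=10,
    )
    kept = await operator.map_messages(
        [
            HumanMessage(content="Tell me a very long story", round_index=1),
            AIMessage(content="Once upon a time... " * 5, round_index=1),
            HumanMessage(content="Thanks", round_index=2),
            AIMessage(content="You're welcome", round_index=2),
        ]
    )
    # The first round is evicted (FIFO) until the rendered history fits the limit
    assert all(m.round_index == 2 for m in kept)

asyncio.run(main())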
def _merge_multi_round_messages(messages: List[List[BaseMessage]]) -> List[BaseMessage]:
# e.g. assert sum([[1, 2], [3, 4], [5, 6]], []) == [1, 2, 3, 4, 5, 6]
return sum(messages, [])


@@ -216,18 +216,27 @@ class HistoryPromptBuilderOperator(
BasePromptBuilderOperator, JoinOperator[List[ModelMessage]]
):
def __init__(
self, prompt: ChatPromptTemplate, history_key: Optional[str] = None, **kwargs
self,
prompt: ChatPromptTemplate,
history_key: Optional[str] = None,
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
self._prompt = prompt
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
async def merge_history(
self, history: List[BaseMessage], prompt_dict: Dict[str, Any]
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:
prompt_dict[self._history_key] = history
return await self.format_prompt(self._prompt, prompt_dict)
@@ -239,9 +248,16 @@ class HistoryDynamicPromptBuilderOperator(
The prompt template is dynamic, and it created by parent operator.
"""
def __init__(self, history_key: Optional[str] = None, **kwargs):
def __init__(
self,
history_key: Optional[str] = None,
check_storage: bool = True,
str_history: bool = False,
**kwargs,
):
self._history_key = history_key
self._str_history = str_history
BasePromptBuilderOperator.__init__(self, check_storage=check_storage)
JoinOperator.__init__(self, combine_function=self.merge_history, **kwargs)
@rearrange_args_by_type
@@ -251,5 +267,8 @@ class HistoryDynamicPromptBuilderOperator(
history: List[BaseMessage],
prompt_dict: Dict[str, Any],
) -> List[ModelMessage]:
prompt_dict[self._history_key] = history
if self._str_history:
prompt_dict[self._history_key] = BaseMessage.messages_to_string(history)
else:
prompt_dict[self._history_key] = history
return await self.format_prompt(prompt, prompt_dict)
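With `str_history=True` (available on both builders above), the history is rendered into the prompt as one role-prefixed string instead of a message list; a construction sketch, where the template is a placeholder:

builder = HistoryPromptBuilderOperator(
    prompt=my_chat_prompt_template,  # placeholder ChatPromptTemplate
    history_key="chat_history",
    str_history=True,  # {chat_history} becomes "Human: ...\nAI: ..."
)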


@@ -49,13 +49,25 @@ class BasePromptTemplate(BaseModel):
"""The prompt template."""
template_format: Optional[str] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: Optional[str] = None
response_key: Optional[str] = "response"
template_is_strict: Optional[bool] = True
"""strict template will check template args"""
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
return _DEFAULT_FORMATTER_MAPPING[self.template_format](True)(
self.template, **kwargs
)
if self.response_format:
kwargs[self.response_key] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
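A sketch of the consolidated `format` path. The template text and format value are invented; that `from_template` forwards these fields is shown in the BaseChatPromptTemplate hunk below:

prompt = BasePromptTemplate.from_template(
    "Question: {question}\nRespond using this JSON format:\n{response}",
    response_format='{"answer": "string"}',
)
text = prompt.format(question="What is DB-GPT?")
# response_format is passed through json.dumps(..., indent=4) and substituted
# for the response_key placeholder ("response" by default)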
@classmethod
def from_template(
@@ -75,10 +87,6 @@ class PromptTemplate(BasePromptTemplate):
template_scene: Optional[str]
template_define: Optional[str]
"""this template define"""
"""strict template will check template args"""
template_is_strict: bool = True
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
response_format: Optional[str]
"""default use stream out"""
stream_out: bool = True
""""""
@@ -103,17 +111,6 @@ class PromptTemplate(BasePromptTemplate):
"""Return the prompt type key."""
return "prompt"
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs."""
if self.template:
if self.response_format:
kwargs["response"] = json.dumps(
self.response_format, ensure_ascii=False, indent=4
)
return _DEFAULT_FORMATTER_MAPPING[self.template_format](
self.template_is_strict
)(self.template, **kwargs)
class BaseChatPromptTemplate(BaseModel, ABC):
prompt: BasePromptTemplate
@@ -129,10 +126,22 @@ class BaseChatPromptTemplate(BaseModel, ABC):
@classmethod
def from_template(
cls, template: str, template_format: Optional[str] = "f-string", **kwargs: Any
cls,
template: str,
template_format: Optional[str] = "f-string",
response_format: Optional[str] = None,
response_key: Optional[str] = "response",
template_is_strict: bool = True,
**kwargs: Any,
) -> BaseChatPromptTemplate:
"""Create a prompt template from a template string."""
prompt = BasePromptTemplate.from_template(template, template_format)
prompt = BasePromptTemplate.from_template(
template,
template_format,
response_format=response_format,
response_key=response_key,
template_is_strict=template_is_strict,
)
return cls(prompt=prompt, **kwargs)


@@ -284,6 +284,15 @@ class StorageInterface(Generic[T, TDataRepresentation], ABC):
resource_id (ResourceIdentifier): The resource identifier of the data
"""
def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
"""Delete the data from the storage.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
"""
for r in resource_id:
self.delete(r)
@abstractmethod
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.


@@ -138,14 +138,14 @@ def test_clear_messages(basic_conversation, human_message):
def test_get_latest_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
latest_message = basic_conversation.get_latest_user_message()
assert latest_message == human_message
assert latest_message.content == human_message.content
def test_get_system_messages(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
system_messages = basic_conversation.get_system_messages()
assert len(system_messages) == 1
assert system_messages[0] == system_message
assert system_messages[0].content == system_message.content
def test_from_conversation(basic_conversation):
@@ -324,6 +324,35 @@ def test_load_from_storage(storage_conversation, in_memory_storage):
assert isinstance(new_conversation.messages[1], AIMessage)
def test_delete(storage_conversation, in_memory_storage):
# Set storage
storage_conversation.conv_storage = in_memory_storage
storage_conversation.message_storage = in_memory_storage
# Add messages and save to storage
storage_conversation.start_new_round()
storage_conversation.add_user_message("User message")
storage_conversation.add_ai_message("AI response")
storage_conversation.end_current_round()
# Create a new StorageConversation instance to load the data
new_conversation = StorageConversation(
"conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
)
# Delete the conversation
new_conversation.delete()
# Check if the conversation is deleted
assert new_conversation.conv_uid == storage_conversation.conv_uid
assert len(new_conversation.messages) == 0
no_messages_conv = StorageConversation(
"conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
)
assert len(no_messages_conv.messages) == 0
def test_parse_model_messages_no_history_messages():
messages = [
ModelMessage(role=ModelMessageRoleType.HUMAN, content="Hello"),


@@ -14,6 +14,7 @@ from dbgpt.core.interface.operator.message_operator import (
BufferedConversationMapperOperator,
ConversationMapperOperator,
PreChatHistoryLoadOperator,
TokenBufferedConversationMapperOperator,
)
from dbgpt.core.interface.operator.prompt_operator import (
DynamicPromptBuilderOperator,
@@ -30,6 +31,7 @@ __ALL__ = [
"BaseStreamingLLMOperator",
"BaseConversationOperator",
"BufferedConversationMapperOperator",
"TokenBufferedConversationMapperOperator",
"ConversationMapperOperator",
"PreChatHistoryLoadOperator",
"PromptBuilderOperator",