Files
DB-GPT/dbgpt/core/interface/message.py
2023-12-25 20:03:22 +08:00

851 lines
30 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Dict, List, Optional, Tuple, Union
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.awel import MapOperator
from dbgpt.core.interface.storage import (
InMemoryStorage,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
class BaseMessage(BaseModel, ABC):
    """Abstract base class of every message kept in a conversation.

    Attributes:
        content: The text of the message.
        index: Index of the message, 0 by default.
        round_index: The round index of the message in the conversation.
        additional_kwargs: Extra data attached to the message, an empty dict
            by default.
    """

    content: str
    index: int = 0
    # The round index of the message in the conversation
    round_index: int = 0
    additional_kwargs: dict = Field(default_factory=dict)

    @property
    @abstractmethod
    def type(self) -> str:
        """Type of the message, used for serialization."""

    @property
    def pass_to_model(self) -> bool:
        """Whether the message will be passed to the model (True by default)."""
        return True

    def to_dict(self) -> Dict:
        """Serialize the message to a plain dict.

        Returns:
            Dict: A dict with the keys ``"type"``, ``"data"`` (the result of
            ``self.dict()``), ``"index"`` and ``"round_index"``.
        """
        serialized: Dict = {"type": self.type}
        serialized["data"] = self.dict()
        serialized["index"] = self.index
        serialized["round_index"] = self.round_index
        return serialized
class HumanMessage(BaseMessage):
    """Message written by the human user."""

    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag of this message kind (``"human"``)."""
        return "human"
class AIMessage(BaseMessage):
    """Message produced by the AI."""

    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag of this message kind (``"ai"``)."""
        return "ai"
class ViewMessage(BaseMessage):
    """Message of type "view"; it is kept in the conversation but is not
    passed to the model.
    """

    example: bool = False

    @property
    def type(self) -> str:
        """Serialization tag of this message kind (``"view"``)."""
        return "view"

    @property
    def pass_to_model(self) -> bool:
        """Whether the message will be passed to the model.

        Always False: the view message will not be passed to the model.
        """
        return False
class SystemMessage(BaseMessage):
    """Message that carries a system instruction."""

    @property
    def type(self) -> str:
        """Serialization tag of this message kind (``"system"``)."""
        return "system"
class ModelMessageRoleType:
    """Role names that a ModelMessage can carry."""

    SYSTEM = "system"
    HUMAN = "human"
    AI = "ai"
    VIEW = "view"
class ModelMessage(BaseModel):
    """Type of message exchanged between dbgpt-server and llm-server.

    Similar to OpenAI's message format.

    Attributes:
        role: One of the ModelMessageRoleType values.
        content: The text of the message.
        round_index: The conversation round of the message, 0 by default.
    """

    role: str
    content: str
    round_index: Optional[int] = 0

    @staticmethod
    def from_openai_messages(
        messages: Union[str, List[Dict[str, str]]]
    ) -> List["ModelMessage"]:
        """Convert OpenAI-format messages to the ModelMessage format.

        Args:
            messages: Either a plain string, which becomes a single human
                message, or a list of dicts with the keys ``"role"`` and
                ``"content"``.

        Returns:
            List[ModelMessage]: The converted messages, in input order.

        Raises:
            ValueError: If a role is not "system", "user" or "assistant".
        """
        if isinstance(messages, str):
            return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
        result = []
        for message in messages:
            msg_role = message["role"]
            content = message["content"]
            if msg_role == "system":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.SYSTEM, content=content)
                )
            elif msg_role == "user":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
                )
            elif msg_role == "assistant":
                result.append(
                    ModelMessage(role=ModelMessageRoleType.AI, content=content)
                )
            else:
                raise ValueError(f"Unknown role: {msg_role}")
        return result

    @staticmethod
    def to_openai_messages(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
        """Convert to the OpenAI message format, which is also the format of
        huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating).

        Messages whose role is not human, system or ai (for example view
        messages) are skipped. The last user message is moved to the end of
        the result.

        Args:
            messages: The messages to convert.

        Returns:
            List[Dict[str, str]]: Dicts with the keys ``"role"`` and
            ``"content"``.
        """
        history = []
        # Add history conversation
        for message in messages:
            if message.role == ModelMessageRoleType.HUMAN:
                history.append({"role": "user", "content": message.content})
            elif message.role == ModelMessageRoleType.SYSTEM:
                history.append({"role": "system", "content": message.content})
            elif message.role == ModelMessageRoleType.AI:
                history.append({"role": "assistant", "content": message.content})
        # Move the last user's information to the end
        last_user_input_index = None
        for i in range(len(history) - 1, -1, -1):
            if history[i]["role"] == "user":
                last_user_input_index = i
                break
        # Compare with None explicitly: index 0 is a valid position but is
        # falsy, so a plain truthiness test would skip it.
        if last_user_input_index is not None:
            history.append(history.pop(last_user_input_index))
        return history

    @staticmethod
    def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
        """Convert each message to a dict with ``dict()``."""
        return [message.dict() for message in messages]

    @staticmethod
    def build_human_message(content: str) -> "ModelMessage":
        """Build a message with the human role and the given content."""
        return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
def _message_to_dict(message: BaseMessage) -> Dict:
return message.to_dict()
def _messages_to_dict(messages: List["BaseMessage"]) -> List[Dict]:
    """Serialize every message of the list, keeping the order."""
    return list(map(_message_to_dict, messages))
def _message_from_dict(message: Dict) -> "BaseMessage":
    """Rebuild a message object from its serialized dict.

    Args:
        message: A dict with a ``"type"`` key and, for a known type, a
            ``"data"`` key holding the constructor keyword arguments.

    Returns:
        BaseMessage: A HumanMessage, AIMessage, SystemMessage or ViewMessage.

    Raises:
        ValueError: If the type is none of "human", "ai", "system", "view".
    """
    _type = message["type"]
    known_types = (
        ("human", HumanMessage),
        ("ai", AIMessage),
        ("system", SystemMessage),
        ("view", ViewMessage),
    )
    for type_name, message_cls in known_types:
        if _type == type_name:
            return message_cls(**message["data"])
    raise ValueError(f"Got unexpected type: {_type}")
def _messages_from_dict(messages: List[Dict]) -> List["BaseMessage"]:
    """Rebuild every message of the list from its dict, keeping the order."""
    return list(map(_message_from_dict, messages))
def parse_model_messages(
    messages: List["ModelMessage"],
) -> Tuple[str, List[str], List[List[str]]]:
    """Parse model messages into the user prompt, system messages and history.

    The last message is expected to be from the user (human) and is treated
    as the current user prompt. From the messages before it, system messages
    are collected separately and human/ai messages are grouped into the
    conversation history. A group is closed by each ai message and is kept
    only if it holds exactly two entries.

    Args:
        messages (List[ModelMessage]): List of messages from a chat conversation.

    Returns:
        tuple: The user prompt, the list of system messages, and the
        conversation history as a list of [user message, ai message] pairs.

    Raises:
        ValueError: If ``messages`` is empty or its last message is not a
            human message.

    Examples:
        .. code-block:: python

            # Example 1: Single round of conversation
            messages = [
                ModelMessage(role="human", content="Hello"),
                ModelMessage(role="ai", content="Hi there!"),
                ModelMessage(role="human", content="How are you?"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "How are you?"
            # system_messages: []
            # history: [["Hello", "Hi there!"]]

            # Example 2: Conversation with system messages
            messages = [
                ModelMessage(role="system", content="System initializing..."),
                ModelMessage(role="human", content="Is it sunny today?"),
                ModelMessage(role="ai", content="Yes, it's sunny."),
                ModelMessage(role="human", content="Great!"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "Great!"
            # system_messages: ["System initializing..."]
            # history: [["Is it sunny today?", "Yes, it's sunny."]]

            # Example 3: Multiple rounds with system message
            messages = [
                ModelMessage(role="human", content="Hi"),
                ModelMessage(role="ai", content="Hello!"),
                ModelMessage(role="system", content="Error 404"),
                ModelMessage(role="human", content="What's the error?"),
                ModelMessage(role="ai", content="Just a joke."),
                ModelMessage(role="human", content="Funny!"),
            ]
            user_prompt, system_messages, history = parse_model_messages(messages)
            # user_prompt: "Funny!"
            # system_messages: ["Error 404"]
            # history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
    """
    # An empty list has no user prompt either, so report it the same way as
    # a conversation that does not end with a human message.
    if not messages or messages[-1].role != "human":
        raise ValueError("Hi! What do you want to talk about")

    system_messages: List[str] = []
    history_messages: List[List[str]] = [[]]
    for message in messages[:-1]:
        if message.role == "human":
            history_messages[-1].append(message.content)
        elif message.role == "system":
            system_messages.append(message.content)
        elif message.role == "ai":
            history_messages[-1].append(message.content)
            # An ai message closes the current group
            history_messages.append([])

    # Keep message a pair of [user message, assistant message]
    history_messages = [pair for pair in history_messages if len(pair) == 2]
    user_prompt = messages[-1].content
    return user_prompt, system_messages, history_messages
class OnceConversation:
    """All the information of a conversation: its metadata and its messages.

    The data lives in the memory of the current single service; cache and
    database support can be added for distributed services.
    """

    def __init__(
        self,
        chat_mode: str,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        summary: Optional[str] = None,
        **kwargs,
    ):
        """Create a conversation.

        Args:
            chat_mode: The chat mode of the conversation.
            user_name: The user name, None by default.
            sys_code: The system code, None by default.
            summary: The summary of the conversation, None by default.
            **kwargs: Optional initial values for ``messages``,
                ``start_date``, ``chat_order``, ``model_name``,
                ``param_type``, ``param_value``, ``cost``, ``tokens`` and
                ``message_index``.
        """
        option = kwargs.get
        self.chat_mode: str = chat_mode
        self.user_name: Optional[str] = user_name
        self.sys_code: Optional[str] = sys_code
        self.summary: Optional[str] = summary
        self.messages: List[BaseMessage] = option("messages", [])
        self.start_date: str = option("start_date", "")
        # After each complete round of dialogue, the current value will be increased by 1
        self.chat_order: int = int(option("chat_order", 0))
        self.model_name: str = option("model_name", "")
        self.param_type: str = option("param_type", "")
        self.param_value: str = option("param_value", "")
        self.cost: int = int(option("cost", 0))
        self.tokens: int = int(option("tokens", 0))
        # Index that the next appended message will get
        self._message_index: int = int(option("message_index", 0))

    def _append_message(self, message: "BaseMessage") -> None:
        """Stamp the message with its index and round, then store it."""
        assigned_index = self._message_index
        self._message_index = assigned_index + 1
        message.index = assigned_index
        message.round_index = self.chat_order
        self.messages.append(message)

    def start_new_round(self) -> None:
        """Start a new round of conversation by increasing the chat order by 1.

        Example:
            >>> conversation = OnceConversation("chat_normal")
            >>> # The chat order will be 0, then we start a new round of conversation
            >>> assert conversation.chat_order == 0
            >>> conversation.start_new_round()
            >>> # Now the chat order will be 1
            >>> assert conversation.chat_order == 1
            >>> conversation.add_user_message("hello")
            >>> conversation.add_ai_message("hi")
            >>> conversation.end_current_round()
            >>> # The chat order is still 1, then we start a new round of conversation
            >>> conversation.start_new_round()
            >>> # Now the chat order will be 2
            >>> assert conversation.chat_order == 2
            >>> conversation.add_user_message("hello")
            >>> conversation.add_ai_message("hi")
            >>> conversation.end_current_round()
            >>> assert conversation.chat_order == 2
        """
        self.chat_order = self.chat_order + 1

    def end_current_round(self) -> None:
        """End the current round of conversation.

        Nothing is done here; the method only exists for the interface.
        """

    def add_user_message(
        self, message: str, check_duplicate_type: Optional[bool] = False
    ) -> None:
        """Add a user message to the conversation.

        Args:
            message (str): The message content
            check_duplicate_type (bool): Whether to refuse the message when
                the conversation already holds a human message

        Raises:
            ValueError: If check_duplicate_type is True and a human message
                already exists
        """
        if check_duplicate_type and any(
            isinstance(existing, HumanMessage) for existing in self.messages
        ):
            raise ValueError("Already Have Human message")
        self._append_message(HumanMessage(content=message))

    def add_ai_message(
        self, message: str, update_if_exist: Optional[bool] = False
    ) -> None:
        """Add an AI message to the conversation.

        Args:
            message (str): The message content
            update_if_exist (bool): When True and the conversation already
                holds an AI message, update the existing content instead of
                appending a new message
        """
        should_update = update_if_exist and any(
            isinstance(existing, AIMessage) for existing in self.messages
        )
        if should_update:
            self._update_ai_message(message)
        else:
            self._append_message(AIMessage(content=message))

    def _update_ai_message(self, new_message: str) -> None:
        """Replace the content of the stored AI messages (stream out update).

        Args:
            new_message: The new content.
        """
        # NOTE(review): this rewrites EVERY message of type "ai", not only
        # the latest one - confirm that is intended for multi-round history.
        ai_messages = (item for item in self.messages if item.type == "ai")
        for ai_message in ai_messages:
            ai_message.content = new_message

    def add_view_message(self, message: str) -> None:
        """Add a view message to the conversation."""
        self._append_message(ViewMessage(content=message))

    def add_system_message(self, message: str) -> None:
        """Add a system message to the conversation."""
        self._append_message(SystemMessage(content=message))

    def set_start_time(self, datatime: datetime):
        """Store the start date formatted as ``%Y-%m-%d %H:%M:%S``."""
        self.start_date = datatime.strftime("%Y-%m-%d %H:%M:%S")

    def clear(self) -> None:
        """Remove all messages from the conversation."""
        self.messages.clear()

    def get_latest_user_message(self) -> Optional["HumanMessage"]:
        """Get the latest user message, or None if there is none."""
        newest_first = reversed(self.messages)
        return next(
            (item for item in newest_first if isinstance(item, HumanMessage)), None
        )

    def get_system_messages(self) -> List["SystemMessage"]:
        """Get all system messages, in conversation order."""
        return [item for item in self.messages if isinstance(item, SystemMessage)]

    def _to_dict(self) -> Dict:
        """Serialize the conversation with ``_conversation_to_dict``."""
        return _conversation_to_dict(self)

    def from_conversation(self, conversation: "OnceConversation") -> None:
        """Overwrite this conversation with the state of another one.

        Args:
            conversation: The conversation to copy from, e.g. one loaded
                from the storage.
        """
        # NOTE(review): ``summary`` is not part of the copied attributes -
        # confirm whether that is intended.
        copied_attributes = (
            "chat_mode",
            "messages",
            "start_date",
            "chat_order",
            "model_name",
            "param_type",
            "param_value",
            "cost",
            "tokens",
            "user_name",
            "sys_code",
            "_message_index",
        )
        for attribute in copied_attributes:
            setattr(self, attribute, getattr(conversation, attribute))

    def get_messages_by_round(self, round_index: int) -> List["BaseMessage"]:
        """Get the messages of one round.

        Args:
            round_index (int): The round index

        Returns:
            List[BaseMessage]: The messages whose round index matches
        """
        return [item for item in self.messages if item.round_index == round_index]

    def get_latest_round(self) -> List["BaseMessage"]:
        """Get the messages of the latest round.

        Returns:
            List[BaseMessage]: The messages of the current chat order
        """
        return self.get_messages_by_round(self.chat_order)

    def get_messages_with_round(self, round_count: int) -> List["BaseMessage"]:
        """Get the messages of the latest ``round_count`` rounds.

        If the round count is 1, the history messages will not be included.

        Example:
            .. code-block:: python

                conversation = OnceConversation("chat_normal")
                conversation.start_new_round()
                conversation.add_user_message("hello, this is the first round")
                conversation.add_ai_message("hi")
                conversation.end_current_round()
                conversation.start_new_round()
                conversation.add_user_message("hello, this is the second round")
                conversation.add_ai_message("hi")
                conversation.end_current_round()
                conversation.start_new_round()
                conversation.add_user_message("hello, this is the third round")
                conversation.add_ai_message("hi")
                conversation.end_current_round()

                assert len(conversation.get_messages_with_round(1)) == 2
                assert (
                    conversation.get_messages_with_round(1)[0].content
                    == "hello, this is the third round"
                )
                assert conversation.get_messages_with_round(1)[1].content == "hi"

                assert len(conversation.get_messages_with_round(2)) == 4
                assert (
                    conversation.get_messages_with_round(2)[0].content
                    == "hello, this is the second round"
                )
                assert conversation.get_messages_with_round(2)[1].content == "hi"

        Args:
            round_count (int): The round count

        Returns:
            List[BaseMessage]: The messages
        """
        newest_round = self.chat_order
        # Rounds are numbered from 1, so never start below round 1
        oldest_round = max(1, newest_round - round_count + 1)
        return [
            item
            for current_round in range(oldest_round, newest_round + 1)
            for item in self.get_messages_by_round(current_round)
        ]

    def get_model_messages(self) -> List["ModelMessage"]:
        """Get the model messages.

        Every stored message whose ``pass_to_model`` is True is converted to
        a ModelMessage. History messages are included, and the order is the
        order of the messages in the conversation, so the last message is the
        latest one. Override this method to handle the messages with your own
        logic.

        Examples:
            If you do not need the history messages, you can override this
            method like this:

            .. code-block:: python

                def get_model_messages(self) -> List[ModelMessage]:
                    messages = []
                    for message in self.get_latest_round():
                        if message.pass_to_model:
                            messages.append(
                                ModelMessage(role=message.type, content=message.content)
                            )
                    return messages

            If you want to add one round of history messages, you can
            override this method like this:

            .. code-block:: python

                def get_model_messages(self) -> List[ModelMessage]:
                    messages = []
                    latest_round_index = self.chat_order
                    round_count = 1
                    start_round_index = max(1, latest_round_index - round_count + 1)
                    for round_index in range(start_round_index, latest_round_index + 1):
                        for message in self.get_messages_by_round(round_index):
                            if message.pass_to_model:
                                messages.append(
                                    ModelMessage(
                                        role=message.type, content=message.content
                                    )
                                )
                    return messages

        Returns:
            List[ModelMessage]: The model messages
        """
        return [
            ModelMessage(
                role=item.type,
                content=item.content,
                round_index=item.round_index,
            )
            for item in self.messages
            if item.pass_to_model
        ]
class ConversationIdentifier(ResourceIdentifier):
    """Identifier of a conversation, built from its uid and a type tag."""

    def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
        """Store the conversation uid and the identifier type."""
        self.conv_uid = conv_uid
        self.identifier_type = identifier_type

    @property
    def str_identifier(self) -> str:
        """The string form ``<identifier_type>:<conv_uid>``."""
        return f"{self.identifier_type}:{self.conv_uid}"

    def to_dict(self) -> Dict:
        """Return the uid and the identifier type as a dict."""
        return dict(conv_uid=self.conv_uid, identifier_type=self.identifier_type)
class MessageIdentifier(ResourceIdentifier):
    """Identifier of a message, built from the conversation uid, the message
    index and a type tag.
    """

    # Separator between the parts of the string identifier
    identifier_split = "___"

    def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
        """Store the conversation uid, the message index and the type tag."""
        self.conv_uid = conv_uid
        self.index = index
        self.identifier_type = identifier_type

    @property
    def str_identifier(self) -> str:
        """The string form ``<identifier_type>___<conv_uid>___<index>``."""
        return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"

    @staticmethod
    def from_str_identifier(str_identifier: str) -> "MessageIdentifier":
        """Convert from str identifier.

        The first separator ends the identifier type and the last separator
        starts the index, so a conversation uid that itself contains the
        separator is parsed correctly. The parsed identifier type is kept, so
        that ``from_str_identifier(x.str_identifier)`` reproduces ``x``.

        Args:
            str_identifier (str): The str identifier

        Returns:
            MessageIdentifier: The message identifier

        Raises:
            ValueError: If the string does not hold three parts or the index
                part is not an integer
        """
        separator = MessageIdentifier.identifier_split
        identifier_type, first_sep, remainder = str_identifier.partition(separator)
        conv_uid, last_sep, index_part = remainder.rpartition(separator)
        # An empty separator result means that the separator was not found
        if not first_sep or not last_sep:
            raise ValueError(f"Invalid str identifier: {str_identifier}")
        return MessageIdentifier(conv_uid, int(index_part), identifier_type)

    def to_dict(self) -> Dict:
        """Return the uid, the index and the identifier type as a dict."""
        return {
            "conv_uid": self.conv_uid,
            "index": self.index,
            "identifier_type": self.identifier_type,
        }
class MessageStorageItem(StorageItem):
    """Storage item that holds one serialized message of a conversation."""

    def __init__(self, conv_uid: str, index: int, message_detail: Dict):
        """Create the item.

        Args:
            conv_uid: The uid of the conversation the message belongs to.
            index: The index of the message.
            message_detail: The serialized message; ``to_message`` passes it
                to ``_message_from_dict``.
        """
        self.conv_uid = conv_uid
        self.index = index
        self.message_detail = message_detail
        self._id = MessageIdentifier(conv_uid, index)

    @property
    def identifier(self) -> "MessageIdentifier":
        """The identifier built from the conversation uid and the index."""
        return self._id

    def to_dict(self) -> Dict:
        """Return the uid, the index and the message detail as a dict."""
        return dict(
            conv_uid=self.conv_uid,
            index=self.index,
            message_detail=self.message_detail,
        )

    def to_message(self) -> "BaseMessage":
        """Convert to message object.

        Returns:
            BaseMessage: The message object

        Raises:
            ValueError: If the message type is not supported
        """
        return _message_from_dict(self.message_detail)

    def merge(self, other: "StorageItem") -> None:
        """Merge the other message to self by taking over its message detail.

        Args:
            other (StorageItem): The other message

        Raises:
            ValueError: If the other item is not a MessageStorageItem
        """
        if isinstance(other, MessageStorageItem):
            self.message_detail = other.message_detail
            return
        raise ValueError(f"Can not merge {other} to {self}")
class StorageConversation(OnceConversation, StorageItem):
    """Conversation that is saved to and loaded from storage.

    The conversation itself goes to ``conv_storage`` and its messages go to
    ``message_storage``; the conversation dict only keeps the message ids.
    """

    def __init__(
        self,
        conv_uid: str,
        chat_mode: Optional[str] = None,
        user_name: Optional[str] = None,
        sys_code: Optional[str] = None,
        message_ids: Optional[List[str]] = None,
        summary: Optional[str] = None,
        save_message_independent: Optional[bool] = True,
        conv_storage: Optional[StorageInterface] = None,
        message_storage: Optional[StorageInterface] = None,
        **kwargs,
    ):
        """Create the conversation and load its stored state, if any.

        Args:
            conv_uid: The uid of the conversation.
            chat_mode: The chat mode, None by default.
            user_name: The user name, None by default.
            sys_code: The system code, None by default.
            message_ids: The string identifiers of the messages.
            summary: The summary of the conversation.
            save_message_independent: Flag that is stored on the instance
                and written by ``to_dict``.
            conv_storage: Storage of conversations; an InMemoryStorage is
                created when None.
            message_storage: Storage of messages; an InMemoryStorage is
                created when None.
            **kwargs: Passed on to OnceConversation.
        """
        super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
        self.conv_uid = conv_uid
        self._message_ids = message_ids
        # Record the message index last time saved to the storage,
        # next time save messages which index is _has_stored_message_index + 1
        if "messages" in kwargs:
            self._has_stored_message_index = len(kwargs["messages"]) - 1
        else:
            self._has_stored_message_index = -1
        self.save_message_independent = save_message_independent
        self._id = ConversationIdentifier(conv_uid)
        if conv_storage is None:
            conv_storage = InMemoryStorage()
        if message_storage is None:
            message_storage = InMemoryStorage()
        self.conv_storage = conv_storage
        self.message_storage = message_storage
        # Load from storage
        self.load_from_storage(self.conv_storage, self.message_storage)

    @property
    def identifier(self) -> "ConversationIdentifier":
        """The identifier built from the conversation uid."""
        return self._id

    @property
    def message_ids(self) -> List[str]:
        """Get the message ids.

        Returns:
            List[str]: The message ids, an empty list if there are none
        """
        return self._message_ids or []

    def to_dict(self) -> Dict:
        """Serialize the conversation, replacing the messages by their ids.

        Returns:
            Dict: The conversation dict without ``"messages"`` and with
            ``"conv_uid"``, ``"message_ids"`` and
            ``"save_message_independent"`` added.
        """
        dict_data = self._to_dict()
        serialized_messages: List[Dict] = dict_data.pop("messages")
        message_ids = []
        # Only used for serialized messages that carry no "index" key
        fallback_index = 0
        for serialized in serialized_messages:
            if "index" in serialized:
                message_idx = serialized["index"]
            else:
                message_idx = fallback_index
                fallback_index += 1
            identifier = MessageIdentifier(self.conv_uid, message_idx)
            message_ids.append(identifier.str_identifier)
        # Replace message with message ids
        dict_data["conv_uid"] = self.conv_uid
        dict_data["message_ids"] = message_ids
        dict_data["save_message_independent"] = self.save_message_independent
        return dict_data

    def merge(self, other: "StorageItem") -> None:
        """Merge the other conversation to self.

        Args:
            other (StorageItem): The other conversation

        Raises:
            ValueError: If the other item is not a StorageConversation
        """
        if isinstance(other, StorageConversation):
            self.from_conversation(other)
            return
        raise ValueError(f"Can not merge {other} to {self}")

    def end_current_round(self) -> None:
        """End the current round of conversation.

        Save the conversation to the storage after a round of conversation.
        """
        self.save_to_storage()

    def _get_message_items(self) -> List["MessageStorageItem"]:
        """Wrap every message of the conversation in a MessageStorageItem."""
        items = []
        for message in self.messages:
            items.append(
                MessageStorageItem(self.conv_uid, message.index, message.to_dict())
            )
        return items

    def save_to_storage(self) -> None:
        """Save the conversation to the storage.

        Only the messages after the last stored index are saved; the
        conversation itself is saved or updated afterwards.
        """
        # Save messages first
        items = self._get_message_items()
        self._message_ids = [item.identifier.str_identifier for item in items]
        first_unsaved = self._has_stored_message_index + 1
        unsaved_items = items[first_unsaved:]
        # NOTE(review): the stored index advances before save_list runs, so
        # later saves skip these messages even if save_list raises - confirm
        # that is intended.
        self._has_stored_message_index = len(items) - 1
        self.message_storage.save_list(unsaved_items)
        # Save conversation
        self.conv_storage.save_or_update(self)

    def load_from_storage(
        self, conv_storage: StorageInterface, message_storage: StorageInterface
    ) -> None:
        """Load the conversation from the storage.

        Warning: This will overwrite the current conversation. Nothing
        happens when the conversation is not found in ``conv_storage``.

        Args:
            conv_storage (StorageInterface): The storage of conversations
            message_storage (StorageInterface): The storage of messages
        """
        # Load conversation first
        stored: StorageConversation = conv_storage.load(self._id, StorageConversation)
        if stored is None:
            return
        message_ids = stored._message_ids or []
        # Load messages
        identifiers = [
            MessageIdentifier.from_str_identifier(message_id)
            for message_id in message_ids
        ]
        items = message_storage.load_list(identifiers, MessageStorageItem)
        messages = [item.to_message() for item in items]
        stored.messages = messages
        # This index is used to save the message to the storage(Has not been saved)
        # The new message append to the messages, so the index is len(messages)
        stored._message_index = len(messages)
        self._message_ids = message_ids
        self._has_stored_message_index = len(messages) - 1
        self.from_conversation(stored)
def _conversation_to_dict(once: "OnceConversation") -> Dict:
    """Serialize a conversation to a plain dict.

    A datetime start date is formatted as ``%Y-%m-%d %H:%M:%S``; a missing or
    empty start date becomes an empty string. Falsy ``cost`` and ``tokens``
    become 0 and a falsy ``summary`` becomes an empty string.
    """
    start_date = getattr(once, "start_date", None)
    if not start_date:
        start_str: str = ""
    elif isinstance(start_date, datetime):
        start_str = start_date.strftime("%Y-%m-%d %H:%M:%S")
    else:
        start_str = start_date
    return {
        "chat_mode": once.chat_mode,
        "model_name": once.model_name,
        "chat_order": once.chat_order,
        "start_date": start_str,
        "cost": once.cost or 0,
        "tokens": once.tokens or 0,
        "messages": _messages_to_dict(once.messages),
        "param_type": once.param_type,
        "param_value": once.param_value,
        "user_name": once.user_name,
        "sys_code": once.sys_code,
        "summary": once.summary or "",
    }
def _conversations_to_dict(conversations: List["OnceConversation"]) -> List[dict]:
    """Serialize every conversation of the list, keeping the order."""
    return list(map(_conversation_to_dict, conversations))
def _conversation_from_dict(once: dict) -> "OnceConversation":
    """Rebuild a conversation from the dict made by ``_conversation_to_dict``.

    Args:
        once: The serialized conversation. Missing keys fall back to
            defaults: chat mode "chat_normal", model name "proxyllm", chat
            order 0, cost 0, tokens 0, no messages.

    Returns:
        OnceConversation: The rebuilt conversation.

    Raises:
        ValueError: If a message has an unsupported type.
    """
    conversation = OnceConversation(
        once.get("chat_mode", "chat_normal"),
        once.get("user_name"),
        once.get("sys_code"),
        # The summary is written by _conversation_to_dict, so restore it too
        once.get("summary"),
    )
    conversation.cost = once.get("cost", 0)
    conversation.tokens = once.get("tokens", 0)
    conversation.start_date = once.get("start_date", "")
    # A missing or None chat order must not crash int()
    conversation.chat_order = int(once.get("chat_order") or 0)
    conversation.param_type = once.get("param_type", "")
    conversation.param_value = once.get("param_value", "")
    conversation.model_name = once.get("model_name", "proxyllm")
    conversation.messages = _messages_from_dict(once.get("messages", []))
    return conversation