chore: Add pylint for DB-GPT core lib (#1076)

This commit is contained in:
Fangyin Cheng
2024-01-16 17:36:26 +08:00
committed by GitHub
parent 3a54d1ef9a
commit 40c853575a
79 changed files with 2213 additions and 839 deletions

View File

@@ -1,8 +1,10 @@
"""The conversation and message module."""
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.storage import (
@@ -29,11 +31,11 @@ class BaseMessage(BaseModel, ABC):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model"""
"""Whether the message will be passed to the model."""
return True
def to_dict(self) -> Dict:
"""Convert to dict
"""Convert to dict.
Returns:
Dict: The dict object
@@ -47,7 +49,7 @@ class BaseMessage(BaseModel, ABC):
@staticmethod
def messages_to_string(messages: List["BaseMessage"]) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (List[BaseMessage]): The messages
@@ -92,7 +94,7 @@ class ViewMessage(BaseMessage):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
"""Whether the message will be passed to the model.
The view message will not be passed to the model
"""
@@ -109,7 +111,7 @@ class SystemMessage(BaseMessage):
class ModelMessageRoleType:
""" "Type of ModelMessage role"""
"""Type of ModelMessage role."""
SYSTEM = "system"
HUMAN = "human"
@@ -118,7 +120,7 @@ class ModelMessageRoleType:
class ModelMessage(BaseModel):
"""Type of message that interaction between dbgpt-server and llm-server"""
"""Type of message exchanged between dbgpt-server and llm-server."""
"""Similar to openai's message format"""
role: str
@@ -127,7 +129,7 @@ class ModelMessage(BaseModel):
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
"""Whether the message will be passed to the model.
The view message will not be passed to the model
@@ -142,6 +144,14 @@ class ModelMessage(BaseModel):
@staticmethod
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
"""Convert BaseMessage format to current ModelMessage format.
Args:
messages (List[BaseMessage]): The base messages
Returns:
List[ModelMessage]: The model messages
"""
result = []
for message in messages:
content, round_index = message.content, message.round_index
@@ -173,7 +183,7 @@ class ModelMessage(BaseModel):
def from_openai_messages(
messages: Union[str, List[Dict[str, str]]]
) -> List["ModelMessage"]:
"""Openai message format to current ModelMessage format"""
"""Openai message format to current ModelMessage format."""
if isinstance(messages, str):
return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
result = []
@@ -202,8 +212,11 @@ class ModelMessage(BaseModel):
convert_to_compatible_format: bool = False,
support_system_role: bool = True,
) -> List[Dict[str, str]]:
"""Convert to common message format(e.g. OpenAI message format) and
huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
"""Convert to common message format.
Convert to common message format(e.g. OpenAI message format) and
huggingface [Templates of Chat Models]
(https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
Args:
messages (List["ModelMessage"]): The model messages
@@ -243,15 +256,38 @@ class ModelMessage(BaseModel):
@staticmethod
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
"""Convert to dict list.
Args:
messages (List["ModelMessage"]): The model messages
Returns:
List[Dict[str, str]]: The dict list
"""
return list(map(lambda m: m.dict(), messages))
@staticmethod
def build_human_message(content: str) -> "ModelMessage":
"""Build human message.
Args:
content (str): The content
Returns:
ModelMessage: The model message
"""
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
@staticmethod
def get_printable_message(messages: List["ModelMessage"]) -> str:
"""Get the printable message"""
"""Get the printable message.
Args:
messages (List["ModelMessage"]): The model messages
Returns:
str: The printable message
"""
str_msg = ""
for message in messages:
curr_message = (
@@ -263,7 +299,7 @@ class ModelMessage(BaseModel):
@staticmethod
def messages_to_string(messages: List["ModelMessage"]) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (List[ModelMessage]): The messages
@@ -287,12 +323,12 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
def _messages_to_str(
messages: List[Union[BaseMessage, ModelMessage]],
messages: Union[List[BaseMessage], List[ModelMessage]],
human_prefix: str = "Human",
ai_prefix: str = "AI",
system_prefix: str = "System",
) -> str:
"""Convert messages to str
"""Convert messages to str.
Args:
messages (Union[List[BaseMessage], List[ModelMessage]]): The messages
@@ -343,21 +379,27 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
def parse_model_messages(
messages: List[ModelMessage],
) -> Tuple[str, List[str], List[List[str, str]]]:
"""
Parse model messages to extract the user prompt, system messages, and a history of conversation.
) -> Tuple[str, List[str], List[List[str]]]:
"""Parse model messages.
This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
the current user prompt. System messages are extracted separately, and the conversation history is compiled into
pairs of human and AI messages.
Parse model messages to extract the user prompt, system messages, and a history of
conversation.
This function analyzes a list of ModelMessage objects, identifying the role of each
message (e.g., human, system, ai)
and categorizes them accordingly. The last message is expected to be from the user
(human), and it's treated as
the current user prompt. System messages are extracted separately, and the
conversation history is compiled into pairs of human and AI messages.
Args:
messages (List[ModelMessage]): List of messages from a chat conversation.
Returns:
tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
The conversation history is a list of message pairs, each containing a user message and the corresponding AI response.
tuple: A tuple containing the user prompt, list of system messages, and the
conversation history.
The conversation history is a list of message pairs, each containing a
user message and the corresponding AI response.
Examples:
.. code-block:: python
@@ -399,7 +441,6 @@ def parse_model_messages(
# system_messages: ["Error 404"]
# history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
"""
system_messages: List[str] = []
history_messages: List[List[str]] = [[]]
@@ -420,27 +461,30 @@ def parse_model_messages(
class OnceConversation:
"""All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""Once conversation.
All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
def __init__(
self,
chat_mode: str,
user_name: str = None,
sys_code: str = None,
summary: str = None,
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
summary: Optional[str] = None,
**kwargs,
):
"""Create a new conversation."""
self.chat_mode: str = chat_mode
self.user_name: str = user_name
self.sys_code: str = sys_code
self.summary: str = summary
self.user_name: Optional[str] = user_name
self.sys_code: Optional[str] = sys_code
self.summary: Optional[str] = summary
self.messages: List[BaseMessage] = kwargs.get("messages", [])
self.start_date: str = kwargs.get("start_date", "")
# After each complete round of dialogue, the current value will be increased by 1
# After each complete round of dialogue, the current value will be
# increased by 1
self.chat_order: int = int(kwargs.get("chat_order", 0))
self.model_name: str = kwargs.get("model_name", "")
self.param_type: str = kwargs.get("param_type", "")
@@ -460,10 +504,9 @@ class OnceConversation:
self.messages.append(message)
def start_new_round(self) -> None:
"""Start a new round of conversation
"""Start a new round of conversation.
Example:
>>> conversation = OnceConversation("chat_normal")
>>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0
@@ -473,7 +516,8 @@ class OnceConversation:
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> # Now the chat order will be 1, then we start a new round of conversation
>>> # Now the chat order will be 1, then we start a new round of
>>> # conversation
>>> conversation.start_new_round()
>>> # Now the chat order will be 2
>>> assert conversation.chat_order == 2
@@ -485,7 +529,7 @@ class OnceConversation:
self.chat_order += 1
def end_current_round(self) -> None:
"""End the current round of conversation
"""Execute the end of the current round of conversation.
We do nothing here, just for the interface
"""
@@ -494,7 +538,7 @@ class OnceConversation:
def add_user_message(
self, message: str, check_duplicate_type: Optional[bool] = False
) -> None:
"""Add a user message to the conversation
"""Save a user message to the conversation.
Args:
message (str): The message content
@@ -514,11 +558,12 @@ class OnceConversation:
def add_ai_message(
self, message: str, update_if_exist: Optional[bool] = False
) -> None:
"""Add an AI message to the conversation
"""Save an AI message to current conversation.
Args:
message (str): The message content
update_if_exist (bool): Whether to update the message if the message type is duplicate
update_if_exist (bool): Whether to update the message if the message type
is duplicate
"""
if not update_if_exist:
self._append_message(AIMessage(content=message))
@@ -530,51 +575,57 @@ class OnceConversation:
self._append_message(AIMessage(content=message))
def _update_ai_message(self, new_message: str) -> None:
"""
"""Update all AI messages to the new message.
stream out message update
Args:
new_message:
Returns:
new_message (str): The new message
"""
for item in self.messages:
if item.type == "ai":
item.content = new_message
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
"""Save a view message to current conversation."""
self._append_message(ViewMessage(content=message))
def add_system_message(self, message: str) -> None:
"""Add a system message to the store"""
"""Save a system message to current conversation."""
self._append_message(SystemMessage(content=message))
def set_start_time(self, datatime: datetime):
"""Set the start time of the conversation."""
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
self.start_date = dt_str
def clear(self) -> None:
"""Remove all messages from the store"""
"""Remove all messages from the store."""
self.messages.clear()
def get_latest_user_message(self) -> Optional[HumanMessage]:
"""Get the latest user message"""
"""Get the latest user message."""
for message in self.messages[::-1]:
if isinstance(message, HumanMessage):
return message
return None
def get_system_messages(self) -> List[SystemMessage]:
"""Get the latest user message"""
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
"""Get the system messages.
Returns:
List[SystemMessage]: The system messages
"""
return cast(
List[SystemMessage],
list(filter(lambda x: isinstance(x, SystemMessage), self.messages)),
)
def _to_dict(self) -> Dict:
return _conversation_to_dict(self)
def from_conversation(self, conversation: OnceConversation) -> None:
"""Load the conversation from the storage"""
"""Load the conversation from the storage."""
self.chat_mode = conversation.chat_mode
self.messages = conversation.messages
self.start_date = conversation.start_date
@@ -592,7 +643,7 @@ class OnceConversation:
self._message_index = conversation._message_index
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
"""Get the messages by round index.
Args:
round_index (int): The round index
@@ -603,7 +654,7 @@ class OnceConversation:
return list(filter(lambda x: x.round_index == round_index, self.messages))
def get_latest_round(self) -> List[BaseMessage]:
"""Get the latest round messages
"""Get the latest round messages.
Returns:
List[BaseMessage]: The messages
@@ -611,7 +662,7 @@ class OnceConversation:
return self.get_messages_by_round(self.chat_order)
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
"""Get the messages with round count
"""Get the messages with round count.
If the round count is 1, the history messages will not be included.
@@ -660,16 +711,19 @@ class OnceConversation:
return messages
def get_model_messages(self) -> List[ModelMessage]:
"""Get the model messages
"""Get the model messages.
Model messages just include human, ai and system messages.
Model messages maybe include the history messages, The order of the messages is the same as the order of
Model messages may include the history messages. The order of the messages is
the same as the order of
the messages in the conversation, the last message is the latest message.
If you want to hand the message with your own logic, you can override this method.
If you want to handle the message with your own logic, you can override this
method.
Examples:
If you not need the history messages, you can override this method like this:
If you do not need the history messages, you can override this method
like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
@@ -681,7 +735,8 @@ class OnceConversation:
)
return messages
If you want to add the one round history messages, you can override this method like this:
If you want to add the one round history messages, you can override this
method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
@@ -717,7 +772,7 @@ class OnceConversation:
def get_history_message(
self, include_system_message: bool = False
) -> List[BaseMessage]:
"""Get the history message
"""Get the history message.
Not include the system messages.
@@ -729,46 +784,60 @@ class OnceConversation:
"""
messages = []
for message in self.messages:
if message.pass_to_model:
if include_system_message:
messages.append(message)
elif message.type != "system":
messages.append(message)
# Keep only messages that are passed to the model; among those, system
# messages are kept only when the caller asked for them. The parentheses are
# required: `and` binds tighter than `or`, so the unparenthesized form
# `a and b or c` would also append every non-system message whose
# `pass_to_model` is False (e.g. view messages).
if message.pass_to_model and (
    include_system_message or message.type != "system"
):
    messages.append(message)
return messages
class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""
"""Conversation identifier."""
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
"""Create a conversation identifier.
Args:
conv_uid (str): The conversation uid
identifier_type (str): The identifier type
"""
self.conv_uid = conv_uid
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
"""Return the str identifier."""
return f"{self.identifier_type}:{self.conv_uid}"
def to_dict(self) -> Dict:
"""Convert to dict."""
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
class MessageIdentifier(ResourceIdentifier):
"""Message identifier"""
"""Message identifier."""
identifier_split = "___"
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
"""Create a message identifier."""
self.conv_uid = conv_uid
self.index = index
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
"""Return the str identifier."""
return (
f"{self.identifier_type}{self.identifier_split}{self.conv_uid}"
f"{self.identifier_split}{self.index}"
)
@staticmethod
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
"""Convert from str identifier
"""Convert from str identifier.
Args:
str_identifier (str): The str identifier
@@ -782,6 +851,7 @@ class MessageIdentifier(ResourceIdentifier):
return MessageIdentifier(parts[1], int(parts[2]))
def to_dict(self) -> Dict:
"""Convert to dict."""
return {
"conv_uid": self.conv_uid,
"index": self.index,
@@ -790,17 +860,31 @@ class MessageIdentifier(ResourceIdentifier):
class MessageStorageItem(StorageItem):
"""The message storage item.
Keep the message detail and the message index.
"""
@property
def identifier(self) -> MessageIdentifier:
"""Return the identifier."""
return self._id
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
"""Create a message storage item.
Args:
conv_uid (str): The conversation uid
index (int): The message index
message_detail (Dict): The message detail
"""
self.conv_uid = conv_uid
self.index = index
self.message_detail = message_detail
self._id = MessageIdentifier(conv_uid, index)
def to_dict(self) -> Dict:
"""Convert to dict."""
return {
"conv_uid": self.conv_uid,
"index": self.index,
@@ -808,7 +892,8 @@ class MessageStorageItem(StorageItem):
}
def to_message(self) -> BaseMessage:
"""Convert to message object
"""Convert to message object.
Returns:
BaseMessage: The message object
@@ -818,7 +903,7 @@ class MessageStorageItem(StorageItem):
return _message_from_dict(self.message_detail)
def merge(self, other: "StorageItem") -> None:
"""Merge the other message to self
"""Merge the other message to self.
Args:
other (StorageItem): The other message
@@ -829,16 +914,20 @@ class MessageStorageItem(StorageItem):
class StorageConversation(OnceConversation, StorageItem):
"""All the information of a conversation, the current single service in memory,
"""The storage conversation.
All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
@property
def identifier(self) -> ConversationIdentifier:
"""Return the identifier."""
return self._id
def to_dict(self) -> Dict:
"""Convert to dict."""
dict_data = self._to_dict()
messages: Dict = dict_data.pop("messages")
message_ids = []
@@ -859,7 +948,7 @@ class StorageConversation(OnceConversation, StorageItem):
return dict_data
def merge(self, other: "StorageItem") -> None:
"""Merge the other conversation to self
"""Merge the other conversation to self.
Args:
other (StorageItem): The other conversation
@@ -871,17 +960,18 @@ class StorageConversation(OnceConversation, StorageItem):
def __init__(
self,
conv_uid: str,
chat_mode: str = None,
user_name: str = None,
sys_code: str = None,
message_ids: List[str] = None,
summary: str = None,
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
chat_mode: str = "chat_normal",
user_name: Optional[str] = None,
sys_code: Optional[str] = None,
message_ids: Optional[List[str]] = None,
summary: Optional[str] = None,
save_message_independent: bool = True,
conv_storage: Optional[StorageInterface] = None,
message_storage: Optional[StorageInterface] = None,
load_message: bool = True,
**kwargs,
):
"""Create a conversation."""
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
self.conv_uid = conv_uid
self._message_ids = message_ids
@@ -905,7 +995,7 @@ class StorageConversation(OnceConversation, StorageItem):
@property
def message_ids(self) -> List[str]:
"""Get the message ids
"""Return the message ids.
Returns:
List[str]: The message ids
@@ -913,7 +1003,7 @@ class StorageConversation(OnceConversation, StorageItem):
return self._message_ids if self._message_ids else []
def end_current_round(self) -> None:
"""End the current round of conversation
"""End the current round of conversation.
Save the conversation to the storage after a round of conversation
"""
@@ -926,7 +1016,7 @@ class StorageConversation(OnceConversation, StorageItem):
]
def save_to_storage(self) -> None:
"""Save the conversation to the storage"""
"""Save the conversation to the storage."""
# Save messages first
message_list = self._get_message_items()
self._message_ids = [
@@ -943,7 +1033,7 @@ class StorageConversation(OnceConversation, StorageItem):
def load_from_storage(
self, conv_storage: StorageInterface, message_storage: StorageInterface
) -> None:
"""Load the conversation from the storage
"""Load the conversation from the storage.
Warning: This will overwrite the current conversation.
@@ -952,7 +1042,7 @@ class StorageConversation(OnceConversation, StorageItem):
message_storage (StorageInterface): The storage interface
"""
# Load conversation first
conversation: StorageConversation = conv_storage.load(
conversation: Optional[StorageConversation] = conv_storage.load(
self._id, StorageConversation
)
if conversation is None:
@@ -988,18 +1078,18 @@ class StorageConversation(OnceConversation, StorageItem):
def _append_additional_kwargs(
self, conversation: StorageConversation, messages: List[BaseMessage]
) -> None:
"""Parse the additional kwargs and append to the conversation
"""Parse the additional kwargs and append to the conversation.
Args:
conversation (StorageConversation): The conversation
messages (List[BaseMessage]): The messages
"""
param_type = None
param_value = None
param_type = ""
param_value = ""
for message in messages[::-1]:
if message.additional_kwargs:
param_type = message.additional_kwargs.get("param_type")
param_value = message.additional_kwargs.get("param_value")
param_type = message.additional_kwargs.get("param_type", "")
param_value = message.additional_kwargs.get("param_value", "")
break
if not conversation.param_type:
conversation.param_type = param_type
@@ -1007,7 +1097,7 @@ class StorageConversation(OnceConversation, StorageItem):
conversation.param_value = param_value
def delete(self) -> None:
"""Delete all the messages and conversation from the storage"""
"""Delete all the messages and conversation."""
# Delete messages first
message_list = self._get_message_items()
message_ids = [message.identifier for message in message_list]
@@ -1055,13 +1145,13 @@ def _conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
def _conversation_from_dict(once: dict) -> OnceConversation:
conversation = OnceConversation(
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
once.get("chat_mode", ""), once.get("user_name"), once.get("sys_code")
)
conversation.cost = once.get("cost", 0)
conversation.chat_mode = once.get("chat_mode", "chat_normal")
conversation.tokens = once.get("tokens", 0)
conversation.start_date = once.get("start_date", "")
conversation.chat_order = int(once.get("chat_order"))
conversation.chat_order = int(once.get("chat_order", 0))
conversation.param_type = once.get("param_type", "")
conversation.param_value = once.get("param_value", "")
conversation.model_name = once.get("model_name", "proxyllm")
@@ -1093,7 +1183,8 @@ def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessa
def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
"""Append the view message to the messages
"""Append the view message to the messages.
Just for show in DB-GPT-Web.
If already have view message, do nothing.