mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-23 10:20:01 +00:00
chore: Add pylint for DB-GPT core lib (#1076)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
"""The conversation and message module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
from dbgpt.core.interface.storage import (
|
||||
@@ -29,11 +31,11 @@ class BaseMessage(BaseModel, ABC):
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model"""
|
||||
"""Whether the message will be passed to the model."""
|
||||
return True
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict
|
||||
"""Convert to dict.
|
||||
|
||||
Returns:
|
||||
Dict: The dict object
|
||||
@@ -47,7 +49,7 @@ class BaseMessage(BaseModel, ABC):
|
||||
|
||||
@staticmethod
|
||||
def messages_to_string(messages: List["BaseMessage"]) -> str:
|
||||
"""Convert messages to str
|
||||
"""Convert messages to str.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): The messages
|
||||
@@ -92,7 +94,7 @@ class ViewMessage(BaseMessage):
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model
|
||||
"""Whether the message will be passed to the model.
|
||||
|
||||
The view message will not be passed to the model
|
||||
"""
|
||||
@@ -109,7 +111,7 @@ class SystemMessage(BaseMessage):
|
||||
|
||||
|
||||
class ModelMessageRoleType:
|
||||
""" "Type of ModelMessage role"""
|
||||
"""Type of ModelMessage role."""
|
||||
|
||||
SYSTEM = "system"
|
||||
HUMAN = "human"
|
||||
@@ -118,7 +120,7 @@ class ModelMessageRoleType:
|
||||
|
||||
|
||||
class ModelMessage(BaseModel):
|
||||
"""Type of message that interaction between dbgpt-server and llm-server"""
|
||||
"""Type of message that interaction between dbgpt-server and llm-server."""
|
||||
|
||||
"""Similar to openai's message format"""
|
||||
role: str
|
||||
@@ -127,7 +129,7 @@ class ModelMessage(BaseModel):
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model
|
||||
"""Whether the message will be passed to the model.
|
||||
|
||||
The view message will not be passed to the model
|
||||
|
||||
@@ -142,6 +144,14 @@ class ModelMessage(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_base_messages(messages: List[BaseMessage]) -> List["ModelMessage"]:
|
||||
"""Covert BaseMessage format to current ModelMessage format.
|
||||
|
||||
Args:
|
||||
messages (List[BaseMessage]): The base messages
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The model messages
|
||||
"""
|
||||
result = []
|
||||
for message in messages:
|
||||
content, round_index = message.content, message.round_index
|
||||
@@ -173,7 +183,7 @@ class ModelMessage(BaseModel):
|
||||
def from_openai_messages(
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List["ModelMessage"]:
|
||||
"""Openai message format to current ModelMessage format"""
|
||||
"""Openai message format to current ModelMessage format."""
|
||||
if isinstance(messages, str):
|
||||
return [ModelMessage(role=ModelMessageRoleType.HUMAN, content=messages)]
|
||||
result = []
|
||||
@@ -202,8 +212,11 @@ class ModelMessage(BaseModel):
|
||||
convert_to_compatible_format: bool = False,
|
||||
support_system_role: bool = True,
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Convert to common message format(e.g. OpenAI message format) and
|
||||
huggingface [Templates of Chat Models](https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
|
||||
"""Cover to common message format.
|
||||
|
||||
Convert to common message format(e.g. OpenAI message format) and
|
||||
huggingface [Templates of Chat Models]
|
||||
(https://huggingface.co/docs/transformers/v4.34.1/en/chat_templating)
|
||||
|
||||
Args:
|
||||
messages (List["ModelMessage"]): The model messages
|
||||
@@ -243,15 +256,38 @@ class ModelMessage(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def to_dict_list(messages: List["ModelMessage"]) -> List[Dict[str, str]]:
|
||||
"""Convert to dict list.
|
||||
|
||||
Args:
|
||||
messages (List["ModelMessage"]): The model messages
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: The dict list
|
||||
"""
|
||||
return list(map(lambda m: m.dict(), messages))
|
||||
|
||||
@staticmethod
|
||||
def build_human_message(content: str) -> "ModelMessage":
|
||||
"""Build human message.
|
||||
|
||||
Args:
|
||||
content (str): The content
|
||||
|
||||
Returns:
|
||||
ModelMessage: The model message
|
||||
"""
|
||||
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
|
||||
|
||||
@staticmethod
|
||||
def get_printable_message(messages: List["ModelMessage"]) -> str:
|
||||
"""Get the printable message"""
|
||||
"""Get the printable message.
|
||||
|
||||
Args:
|
||||
messages (List["ModelMessage"]): The model messages
|
||||
|
||||
Returns:
|
||||
str: The printable message
|
||||
"""
|
||||
str_msg = ""
|
||||
for message in messages:
|
||||
curr_message = (
|
||||
@@ -263,7 +299,7 @@ class ModelMessage(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def messages_to_string(messages: List["ModelMessage"]) -> str:
|
||||
"""Convert messages to str
|
||||
"""Convert messages to str.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): The messages
|
||||
@@ -287,12 +323,12 @@ def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
|
||||
|
||||
|
||||
def _messages_to_str(
|
||||
messages: List[Union[BaseMessage, ModelMessage]],
|
||||
messages: Union[List[BaseMessage], List[ModelMessage]],
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "AI",
|
||||
system_prefix: str = "System",
|
||||
) -> str:
|
||||
"""Convert messages to str
|
||||
"""Convert messages to str.
|
||||
|
||||
Args:
|
||||
messages (List[Union[BaseMessage, ModelMessage]]): The messages
|
||||
@@ -343,21 +379,27 @@ def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
|
||||
|
||||
def parse_model_messages(
|
||||
messages: List[ModelMessage],
|
||||
) -> Tuple[str, List[str], List[List[str, str]]]:
|
||||
"""
|
||||
Parse model messages to extract the user prompt, system messages, and a history of conversation.
|
||||
) -> Tuple[str, List[str], List[List[str]]]:
|
||||
"""Parse model messages.
|
||||
|
||||
This function analyzes a list of ModelMessage objects, identifying the role of each message (e.g., human, system, ai)
|
||||
and categorizes them accordingly. The last message is expected to be from the user (human), and it's treated as
|
||||
the current user prompt. System messages are extracted separately, and the conversation history is compiled into
|
||||
pairs of human and AI messages.
|
||||
Parse model messages to extract the user prompt, system messages, and a history of
|
||||
conversation.
|
||||
|
||||
This function analyzes a list of ModelMessage objects, identifying the role of each
|
||||
message (e.g., human, system, ai)
|
||||
and categorizes them accordingly. The last message is expected to be from the user
|
||||
(human), and it's treated as
|
||||
the current user prompt. System messages are extracted separately, and the
|
||||
conversation history is compiled into pairs of human and AI messages.
|
||||
|
||||
Args:
|
||||
messages (List[ModelMessage]): List of messages from a chat conversation.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the user prompt, list of system messages, and the conversation history.
|
||||
The conversation history is a list of message pairs, each containing a user message and the corresponding AI response.
|
||||
tuple: A tuple containing the user prompt, list of system messages, and the
|
||||
conversation history.
|
||||
The conversation history is a list of message pairs, each containing a
|
||||
user message and the corresponding AI response.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -399,7 +441,6 @@ def parse_model_messages(
|
||||
# system_messages: ["Error 404"]
|
||||
# history: [["Hi", "Hello!"], ["What's the error?", "Just a joke."]]
|
||||
"""
|
||||
|
||||
system_messages: List[str] = []
|
||||
history_messages: List[List[str]] = [[]]
|
||||
|
||||
@@ -420,27 +461,30 @@ def parse_model_messages(
|
||||
|
||||
|
||||
class OnceConversation:
|
||||
"""All the information of a conversation, the current single service in memory,
|
||||
can expand cache and database support distributed services.
|
||||
"""Once conversation.
|
||||
|
||||
All the information of a conversation, the current single service in memory,
|
||||
can expand cache and database support distributed services.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chat_mode: str,
|
||||
user_name: str = None,
|
||||
sys_code: str = None,
|
||||
summary: str = None,
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
summary: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a new conversation."""
|
||||
self.chat_mode: str = chat_mode
|
||||
self.user_name: str = user_name
|
||||
self.sys_code: str = sys_code
|
||||
self.summary: str = summary
|
||||
self.user_name: Optional[str] = user_name
|
||||
self.sys_code: Optional[str] = sys_code
|
||||
self.summary: Optional[str] = summary
|
||||
|
||||
self.messages: List[BaseMessage] = kwargs.get("messages", [])
|
||||
self.start_date: str = kwargs.get("start_date", "")
|
||||
# After each complete round of dialogue, the current value will be increased by 1
|
||||
# After each complete round of dialogue, the current value will be
|
||||
# increased by 1
|
||||
self.chat_order: int = int(kwargs.get("chat_order", 0))
|
||||
self.model_name: str = kwargs.get("model_name", "")
|
||||
self.param_type: str = kwargs.get("param_type", "")
|
||||
@@ -460,10 +504,9 @@ class OnceConversation:
|
||||
self.messages.append(message)
|
||||
|
||||
def start_new_round(self) -> None:
|
||||
"""Start a new round of conversation
|
||||
"""Start a new round of conversation.
|
||||
|
||||
Example:
|
||||
|
||||
>>> conversation = OnceConversation("chat_normal")
|
||||
>>> # The chat order will be 0, then we start a new round of conversation
|
||||
>>> assert conversation.chat_order == 0
|
||||
@@ -473,7 +516,8 @@ class OnceConversation:
|
||||
>>> conversation.add_user_message("hello")
|
||||
>>> conversation.add_ai_message("hi")
|
||||
>>> conversation.end_current_round()
|
||||
>>> # Now the chat order will be 1, then we start a new round of conversation
|
||||
>>> # Now the chat order will be 1, then we start a new round of
|
||||
>>> # conversation
|
||||
>>> conversation.start_new_round()
|
||||
>>> # Now the chat order will be 2
|
||||
>>> assert conversation.chat_order == 2
|
||||
@@ -485,7 +529,7 @@ class OnceConversation:
|
||||
self.chat_order += 1
|
||||
|
||||
def end_current_round(self) -> None:
|
||||
"""End the current round of conversation
|
||||
"""Execute the end of the current round of conversation.
|
||||
|
||||
We do noting here, just for the interface
|
||||
"""
|
||||
@@ -494,7 +538,7 @@ class OnceConversation:
|
||||
def add_user_message(
|
||||
self, message: str, check_duplicate_type: Optional[bool] = False
|
||||
) -> None:
|
||||
"""Add a user message to the conversation
|
||||
"""Save a user message to the conversation.
|
||||
|
||||
Args:
|
||||
message (str): The message content
|
||||
@@ -514,11 +558,12 @@ class OnceConversation:
|
||||
def add_ai_message(
|
||||
self, message: str, update_if_exist: Optional[bool] = False
|
||||
) -> None:
|
||||
"""Add an AI message to the conversation
|
||||
"""Save an AI message to current conversation.
|
||||
|
||||
Args:
|
||||
message (str): The message content
|
||||
update_if_exist (bool): Whether to update the message if the message type is duplicate
|
||||
update_if_exist (bool): Whether to update the message if the message type
|
||||
is duplicate
|
||||
"""
|
||||
if not update_if_exist:
|
||||
self._append_message(AIMessage(content=message))
|
||||
@@ -530,51 +575,57 @@ class OnceConversation:
|
||||
self._append_message(AIMessage(content=message))
|
||||
|
||||
def _update_ai_message(self, new_message: str) -> None:
|
||||
"""
|
||||
"""Update the all AI message to new message.
|
||||
|
||||
stream out message update
|
||||
|
||||
Args:
|
||||
new_message:
|
||||
|
||||
Returns:
|
||||
|
||||
new_message (str): The new message
|
||||
"""
|
||||
|
||||
for item in self.messages:
|
||||
if item.type == "ai":
|
||||
item.content = new_message
|
||||
|
||||
def add_view_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
"""Save a view message to current conversation."""
|
||||
self._append_message(ViewMessage(content=message))
|
||||
|
||||
def add_system_message(self, message: str) -> None:
|
||||
"""Add a system message to the store"""
|
||||
"""Save a system message to current conversation."""
|
||||
self._append_message(SystemMessage(content=message))
|
||||
|
||||
def set_start_time(self, datatime: datetime):
|
||||
"""Set the start time of the conversation."""
|
||||
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.start_date = dt_str
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
"""Remove all messages from the store."""
|
||||
self.messages.clear()
|
||||
|
||||
def get_latest_user_message(self) -> Optional[HumanMessage]:
|
||||
"""Get the latest user message"""
|
||||
"""Get the latest user message."""
|
||||
for message in self.messages[::-1]:
|
||||
if isinstance(message, HumanMessage):
|
||||
return message
|
||||
return None
|
||||
|
||||
def get_system_messages(self) -> List[SystemMessage]:
|
||||
"""Get the latest user message"""
|
||||
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
|
||||
"""Get the latest user message.
|
||||
|
||||
Returns:
|
||||
List[SystemMessage]: The system messages
|
||||
"""
|
||||
return cast(
|
||||
List[SystemMessage],
|
||||
list(filter(lambda x: isinstance(x, SystemMessage), self.messages)),
|
||||
)
|
||||
|
||||
def _to_dict(self) -> Dict:
|
||||
return _conversation_to_dict(self)
|
||||
|
||||
def from_conversation(self, conversation: OnceConversation) -> None:
|
||||
"""Load the conversation from the storage"""
|
||||
"""Load the conversation from the storage."""
|
||||
self.chat_mode = conversation.chat_mode
|
||||
self.messages = conversation.messages
|
||||
self.start_date = conversation.start_date
|
||||
@@ -592,7 +643,7 @@ class OnceConversation:
|
||||
self._message_index = conversation._message_index
|
||||
|
||||
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||
"""Get the messages by round index
|
||||
"""Get the messages by round index.
|
||||
|
||||
Args:
|
||||
round_index (int): The round index
|
||||
@@ -603,7 +654,7 @@ class OnceConversation:
|
||||
return list(filter(lambda x: x.round_index == round_index, self.messages))
|
||||
|
||||
def get_latest_round(self) -> List[BaseMessage]:
|
||||
"""Get the latest round messages
|
||||
"""Get the latest round messages.
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The messages
|
||||
@@ -611,7 +662,7 @@ class OnceConversation:
|
||||
return self.get_messages_by_round(self.chat_order)
|
||||
|
||||
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
|
||||
"""Get the messages with round count
|
||||
"""Get the messages with round count.
|
||||
|
||||
If the round count is 1, the history messages will not be included.
|
||||
|
||||
@@ -660,16 +711,19 @@ class OnceConversation:
|
||||
return messages
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
"""Get the model messages
|
||||
"""Get the model messages.
|
||||
|
||||
Model messages just include human, ai and system messages.
|
||||
Model messages maybe include the history messages, The order of the messages is the same as the order of
|
||||
Model messages maybe include the history messages, The order of the messages is
|
||||
the same as the order of
|
||||
the messages in the conversation, the last message is the latest message.
|
||||
|
||||
If you want to hand the message with your own logic, you can override this method.
|
||||
If you want to hand the message with your own logic, you can override this
|
||||
method.
|
||||
|
||||
Examples:
|
||||
If you not need the history messages, you can override this method like this:
|
||||
If you not need the history messages, you can override this method
|
||||
like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
@@ -681,7 +735,8 @@ class OnceConversation:
|
||||
)
|
||||
return messages
|
||||
|
||||
If you want to add the one round history messages, you can override this method like this:
|
||||
If you want to add the one round history messages, you can override this
|
||||
method like this:
|
||||
.. code-block:: python
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
@@ -717,7 +772,7 @@ class OnceConversation:
|
||||
def get_history_message(
|
||||
self, include_system_message: bool = False
|
||||
) -> List[BaseMessage]:
|
||||
"""Get the history message
|
||||
"""Get the history message.
|
||||
|
||||
Not include the system messages.
|
||||
|
||||
@@ -729,46 +784,60 @@ class OnceConversation:
|
||||
"""
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if message.pass_to_model:
|
||||
if include_system_message:
|
||||
messages.append(message)
|
||||
elif message.type != "system":
|
||||
messages.append(message)
|
||||
if (
|
||||
message.pass_to_model
|
||||
and include_system_message
|
||||
or message.type != "system"
|
||||
):
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
|
||||
class ConversationIdentifier(ResourceIdentifier):
|
||||
"""Conversation identifier"""
|
||||
"""Conversation identifier."""
|
||||
|
||||
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
|
||||
"""Create a conversation identifier.
|
||||
|
||||
Args:
|
||||
conv_uid (str): The conversation uid
|
||||
identifier_type (str): The identifier type
|
||||
"""
|
||||
self.conv_uid = conv_uid
|
||||
self.identifier_type = identifier_type
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
"""Return the str identifier."""
|
||||
return f"{self.identifier_type}:{self.conv_uid}"
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
|
||||
|
||||
|
||||
class MessageIdentifier(ResourceIdentifier):
|
||||
"""Message identifier"""
|
||||
"""Message identifier."""
|
||||
|
||||
identifier_split = "___"
|
||||
|
||||
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
|
||||
"""Create a message identifier."""
|
||||
self.conv_uid = conv_uid
|
||||
self.index = index
|
||||
self.identifier_type = identifier_type
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
|
||||
"""Return the str identifier."""
|
||||
return (
|
||||
f"{self.identifier_type}{self.identifier_split}{self.conv_uid}"
|
||||
f"{self.identifier_split}{self.index}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
|
||||
"""Convert from str identifier
|
||||
"""Convert from str identifier.
|
||||
|
||||
Args:
|
||||
str_identifier (str): The str identifier
|
||||
@@ -782,6 +851,7 @@ class MessageIdentifier(ResourceIdentifier):
|
||||
return MessageIdentifier(parts[1], int(parts[2]))
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
return {
|
||||
"conv_uid": self.conv_uid,
|
||||
"index": self.index,
|
||||
@@ -790,17 +860,31 @@ class MessageIdentifier(ResourceIdentifier):
|
||||
|
||||
|
||||
class MessageStorageItem(StorageItem):
|
||||
"""The message storage item.
|
||||
|
||||
Keep the message detail and the message index.
|
||||
"""
|
||||
|
||||
@property
|
||||
def identifier(self) -> MessageIdentifier:
|
||||
"""Return the identifier."""
|
||||
return self._id
|
||||
|
||||
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
|
||||
"""Create a message storage item.
|
||||
|
||||
Args:
|
||||
conv_uid (str): The conversation uid
|
||||
index (int): The message index
|
||||
message_detail (Dict): The message detail
|
||||
"""
|
||||
self.conv_uid = conv_uid
|
||||
self.index = index
|
||||
self.message_detail = message_detail
|
||||
self._id = MessageIdentifier(conv_uid, index)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
return {
|
||||
"conv_uid": self.conv_uid,
|
||||
"index": self.index,
|
||||
@@ -808,7 +892,8 @@ class MessageStorageItem(StorageItem):
|
||||
}
|
||||
|
||||
def to_message(self) -> BaseMessage:
|
||||
"""Convert to message object
|
||||
"""Convert to message object.
|
||||
|
||||
Returns:
|
||||
BaseMessage: The message object
|
||||
|
||||
@@ -818,7 +903,7 @@ class MessageStorageItem(StorageItem):
|
||||
return _message_from_dict(self.message_detail)
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other message to self
|
||||
"""Merge the other message to self.
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other message
|
||||
@@ -829,16 +914,20 @@ class MessageStorageItem(StorageItem):
|
||||
|
||||
|
||||
class StorageConversation(OnceConversation, StorageItem):
|
||||
"""All the information of a conversation, the current single service in memory,
|
||||
"""The storage conversation.
|
||||
|
||||
All the information of a conversation, the current single service in memory,
|
||||
can expand cache and database support distributed services.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def identifier(self) -> ConversationIdentifier:
|
||||
"""Return the identifier."""
|
||||
return self._id
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict."""
|
||||
dict_data = self._to_dict()
|
||||
messages: Dict = dict_data.pop("messages")
|
||||
message_ids = []
|
||||
@@ -859,7 +948,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
return dict_data
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other conversation to self
|
||||
"""Merge the other conversation to self.
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other conversation
|
||||
@@ -871,17 +960,18 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
def __init__(
|
||||
self,
|
||||
conv_uid: str,
|
||||
chat_mode: str = None,
|
||||
user_name: str = None,
|
||||
sys_code: str = None,
|
||||
message_ids: List[str] = None,
|
||||
summary: str = None,
|
||||
save_message_independent: Optional[bool] = True,
|
||||
conv_storage: StorageInterface = None,
|
||||
message_storage: StorageInterface = None,
|
||||
chat_mode: str = "chat_normal",
|
||||
user_name: Optional[str] = None,
|
||||
sys_code: Optional[str] = None,
|
||||
message_ids: Optional[List[str]] = None,
|
||||
summary: Optional[str] = None,
|
||||
save_message_independent: bool = True,
|
||||
conv_storage: Optional[StorageInterface] = None,
|
||||
message_storage: Optional[StorageInterface] = None,
|
||||
load_message: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a conversation."""
|
||||
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
|
||||
self.conv_uid = conv_uid
|
||||
self._message_ids = message_ids
|
||||
@@ -905,7 +995,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
|
||||
@property
|
||||
def message_ids(self) -> List[str]:
|
||||
"""Get the message ids
|
||||
"""Return the message ids.
|
||||
|
||||
Returns:
|
||||
List[str]: The message ids
|
||||
@@ -913,7 +1003,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
return self._message_ids if self._message_ids else []
|
||||
|
||||
def end_current_round(self) -> None:
|
||||
"""End the current round of conversation
|
||||
"""End the current round of conversation.
|
||||
|
||||
Save the conversation to the storage after a round of conversation
|
||||
"""
|
||||
@@ -926,7 +1016,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
]
|
||||
|
||||
def save_to_storage(self) -> None:
|
||||
"""Save the conversation to the storage"""
|
||||
"""Save the conversation to the storage."""
|
||||
# Save messages first
|
||||
message_list = self._get_message_items()
|
||||
self._message_ids = [
|
||||
@@ -943,7 +1033,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
def load_from_storage(
|
||||
self, conv_storage: StorageInterface, message_storage: StorageInterface
|
||||
) -> None:
|
||||
"""Load the conversation from the storage
|
||||
"""Load the conversation from the storage.
|
||||
|
||||
Warning: This will overwrite the current conversation.
|
||||
|
||||
@@ -952,7 +1042,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
message_storage (StorageInterface): The storage interface
|
||||
"""
|
||||
# Load conversation first
|
||||
conversation: StorageConversation = conv_storage.load(
|
||||
conversation: Optional[StorageConversation] = conv_storage.load(
|
||||
self._id, StorageConversation
|
||||
)
|
||||
if conversation is None:
|
||||
@@ -988,18 +1078,18 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
def _append_additional_kwargs(
|
||||
self, conversation: StorageConversation, messages: List[BaseMessage]
|
||||
) -> None:
|
||||
"""Parse the additional kwargs and append to the conversation
|
||||
"""Parse the additional kwargs and append to the conversation.
|
||||
|
||||
Args:
|
||||
conversation (StorageConversation): The conversation
|
||||
messages (List[BaseMessage]): The messages
|
||||
"""
|
||||
param_type = None
|
||||
param_value = None
|
||||
param_type = ""
|
||||
param_value = ""
|
||||
for message in messages[::-1]:
|
||||
if message.additional_kwargs:
|
||||
param_type = message.additional_kwargs.get("param_type")
|
||||
param_value = message.additional_kwargs.get("param_value")
|
||||
param_type = message.additional_kwargs.get("param_type", "")
|
||||
param_value = message.additional_kwargs.get("param_value", "")
|
||||
break
|
||||
if not conversation.param_type:
|
||||
conversation.param_type = param_type
|
||||
@@ -1007,7 +1097,7 @@ class StorageConversation(OnceConversation, StorageItem):
|
||||
conversation.param_value = param_value
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete all the messages and conversation from the storage"""
|
||||
"""Delete all the messages and conversation."""
|
||||
# Delete messages first
|
||||
message_list = self._get_message_items()
|
||||
message_ids = [message.identifier for message in message_list]
|
||||
@@ -1055,13 +1145,13 @@ def _conversations_to_dict(conversations: List[OnceConversation]) -> List[dict]:
|
||||
|
||||
def _conversation_from_dict(once: dict) -> OnceConversation:
|
||||
conversation = OnceConversation(
|
||||
once.get("chat_mode"), once.get("user_name"), once.get("sys_code")
|
||||
once.get("chat_mode", ""), once.get("user_name"), once.get("sys_code")
|
||||
)
|
||||
conversation.cost = once.get("cost", 0)
|
||||
conversation.chat_mode = once.get("chat_mode", "chat_normal")
|
||||
conversation.tokens = once.get("tokens", 0)
|
||||
conversation.start_date = once.get("start_date", "")
|
||||
conversation.chat_order = int(once.get("chat_order"))
|
||||
conversation.chat_order = int(once.get("chat_order", 0))
|
||||
conversation.param_type = once.get("param_type", "")
|
||||
conversation.param_value = once.get("param_value", "")
|
||||
conversation.model_name = once.get("model_name", "proxyllm")
|
||||
@@ -1093,7 +1183,8 @@ def _split_messages_by_round(messages: List[BaseMessage]) -> List[List[BaseMessa
|
||||
|
||||
|
||||
def _append_view_messages(messages: List[BaseMessage]) -> List[BaseMessage]:
|
||||
"""Append the view message to the messages
|
||||
"""Append the view message to the messages.
|
||||
|
||||
Just for show in DB-GPT-Web.
|
||||
If already have view message, do nothing.
|
||||
|
||||
|
Reference in New Issue
Block a user