refactor: Refactor storage system (#937)

Fangyin Cheng
2023-12-15 16:35:45 +08:00
committed by GitHub
parent a1e415d68d
commit aed1c3fb2b
55 changed files with 3780 additions and 680 deletions

View File

@@ -9,6 +9,10 @@ from dbgpt.core.interface.message import (
ModelMessage,
ModelMessageRoleType,
OnceConversation,
StorageConversation,
MessageStorageItem,
ConversationIdentifier,
MessageIdentifier,
)
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
@@ -20,6 +24,16 @@ from dbgpt.core.interface.cache import (
CachePolicy,
CacheConfig,
)
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageItemAdapter,
StorageInterface,
InMemoryStorage,
DefaultStorageItemAdapter,
QuerySpec,
StorageError,
)
__ALL__ = [
"ModelInferenceMetrics",
@@ -30,6 +44,10 @@ __ALL__ = [
"ModelMessage",
"ModelMessageRoleType",
"OnceConversation",
"StorageConversation",
"MessageStorageItem",
"ConversationIdentifier",
"MessageIdentifier",
"PromptTemplate",
"PromptTemplateOperator",
"BaseOutputParser",
@@ -41,4 +59,12 @@ __ALL__ = [
"CacheClient",
"CachePolicy",
"CacheConfig",
"ResourceIdentifier",
"StorageItem",
"StorageItemAdapter",
"StorageInterface",
"InMemoryStorage",
"DefaultStorageItemAdapter",
"QuerySpec",
"StorageError",
]
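With these re-exports in place, the new storage primitives are importable directly from dbgpt.core. A minimal sketch, assuming the re-exports resolve as listed above (the sys_code condition is just an illustrative attribute filter):

from dbgpt.core import InMemoryStorage, QuerySpec

storage = InMemoryStorage()  # in-memory backend with the default JSON serializer
spec = QuerySpec(conditions={"sys_code": "sys1"}, limit=10)  # match sys_code == "sys1", at most 10 items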

View File

@@ -1,7 +1,7 @@
import pytest
import threading
import asyncio
from ..dag import DAG, DAGContext
from ..base import DAG, DAGVar
def test_dag_context_sync():
@@ -9,18 +9,18 @@ def test_dag_context_sync():
dag2 = DAG("dag2")
with dag1:
assert DAGContext.get_current_dag() == dag1
assert DAGVar.get_current_dag() == dag1
with dag2:
assert DAGContext.get_current_dag() == dag2
assert DAGContext.get_current_dag() == dag1
assert DAGContext.get_current_dag() is None
assert DAGVar.get_current_dag() == dag2
assert DAGVar.get_current_dag() == dag1
assert DAGVar.get_current_dag() is None
def test_dag_context_threading():
def thread_function(dag):
DAGContext.enter_dag(dag)
assert DAGContext.get_current_dag() == dag
DAGContext.exit_dag()
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
@@ -33,19 +33,19 @@ def test_dag_context_threading():
thread1.join()
thread2.join()
assert DAGContext.get_current_dag() is None
assert DAGVar.get_current_dag() is None
@pytest.mark.asyncio
async def test_dag_context_async():
async def async_function(dag):
DAGContext.enter_dag(dag)
assert DAGContext.get_current_dag() == dag
DAGContext.exit_dag()
DAGVar.enter_dag(dag)
assert DAGVar.get_current_dag() == dag
DAGVar.exit_dag()
dag1 = DAG("dag1")
dag2 = DAG("dag2")
await asyncio.gather(async_function(dag1), async_function(dag2))
assert DAGContext.get_current_dag() is None
assert DAGVar.get_current_dag() is None

View File

@@ -1,16 +1,26 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Tuple, Union, Optional
from datetime import datetime
from dbgpt._private.pydantic import BaseModel, Field
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageItem,
StorageInterface,
InMemoryStorage,
)
class BaseMessage(BaseModel, ABC):
"""Message object."""
content: str
index: int = 0
round_index: int = 0
"""The round index of the message in the conversation"""
additional_kwargs: dict = Field(default_factory=dict)
@property
@@ -18,6 +28,24 @@ class BaseMessage(BaseModel, ABC):
def type(self) -> str:
"""Type of the message, used for serialization."""
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model"""
return True
def to_dict(self) -> Dict:
"""Convert to dict
Returns:
Dict: The dict object
"""
return {
"type": self.type,
"data": self.dict(),
"index": self.index,
"round_index": self.round_index,
}
class HumanMessage(BaseMessage):
"""Type of message that is spoken by the human."""
@@ -51,6 +79,14 @@ class ViewMessage(BaseMessage):
"""Type of the message, used for serialization."""
return "view"
@property
def pass_to_model(self) -> bool:
"""Whether the message will be passed to the model
The view message will not be passed to the model
"""
return False
class SystemMessage(BaseMessage):
"""Type of message that is a system message."""
@@ -141,15 +177,15 @@ class ModelMessage(BaseModel):
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}
def _message_to_dict(message: BaseMessage) -> Dict:
return message.to_dict()
def _messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
return [_message_to_dict(m) for m in messages]
def _message_from_dict(message: dict) -> BaseMessage:
def _message_from_dict(message: Dict) -> BaseMessage:
_type = message["type"]
if _type == "human":
return HumanMessage(**message["data"])
@@ -163,7 +199,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
raise ValueError(f"Got unexpected type: {_type}")
def _messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
return [_message_from_dict(m) for m in messages]
@@ -193,50 +229,119 @@ def _parse_model_messages(
history_messages.append([])
if messages[-1].role != "human":
raise ValueError("Hi! What do you want to talk about")
# Keep message pair of [user message, assistant message]
# Keep only [user message, assistant message] pairs
history_messages = list(filter(lambda x: len(x) == 2, history_messages))
user_prompt = messages[-1].content
return user_prompt, system_messages, history_messages
class OnceConversation:
"""
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
"""All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
def __init__(
self,
chat_mode: str,
user_name: str = None,
sys_code: str = None,
summary: str = None,
**kwargs,
):
self.chat_mode: str = chat_mode
self.messages: List[BaseMessage] = []
self.start_date: str = ""
self.chat_order: int = 0
self.model_name: str = ""
self.param_type: str = ""
self.param_value: str = ""
self.cost: int = 0
self.tokens: int = 0
self.user_name: str = user_name
self.sys_code: str = sys_code
self.summary: str = summary
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
has_message = any(
isinstance(instance, HumanMessage) for instance in self.messages
)
if has_message:
raise ValueError("Already Have Human message")
self.messages.append(HumanMessage(content=message))
self.messages: List[BaseMessage] = kwargs.get("messages", [])
self.start_date: str = kwargs.get("start_date", "")
# After each complete round of dialogue, the current value will be increased by 1
self.chat_order: int = int(kwargs.get("chat_order", 0))
self.model_name: str = kwargs.get("model_name", "")
self.param_type: str = kwargs.get("param_type", "")
self.param_value: str = kwargs.get("param_value", "")
self.cost: int = int(kwargs.get("cost", 0))
self.tokens: int = int(kwargs.get("tokens", 0))
self._message_index: int = int(kwargs.get("message_index", 0))
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
def _append_message(self, message: BaseMessage) -> None:
index = self._message_index
self._message_index += 1
message.index = index
message.round_index = self.chat_order
self.messages.append(message)
def start_new_round(self) -> None:
"""Start a new round of conversation
Example:
>>> conversation = OnceConversation()
>>> # The chat order will be 0, then we start a new round of conversation
>>> assert conversation.chat_order == 0
>>> conversation.start_new_round()
>>> # Now the chat order will be 1
>>> assert conversation.chat_order == 1
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> # Now the chat order will be 1, then we start a new round of conversation
>>> conversation.start_new_round()
>>> # Now the chat order will be 2
>>> assert conversation.chat_order == 2
>>> conversation.add_user_message("hello")
>>> conversation.add_ai_message("hi")
>>> conversation.end_current_round()
>>> assert conversation.chat_order == 2
"""
self.chat_order += 1
def end_current_round(self) -> None:
"""End the current round of conversation
We do nothing here; this exists only for the interface
"""
pass
def add_user_message(
self, message: str, check_duplicate_type: Optional[bool] = False
) -> None:
"""Add a user message to the conversation
Args:
message (str): The message content
check_duplicate_type (bool): Whether to check for an existing message of the same type
Raises:
ValueError: If a human message already exists and check_duplicate_type is True
"""
if check_duplicate_type:
has_message = any(
isinstance(instance, HumanMessage) for instance in self.messages
)
if has_message:
raise ValueError("Already Have Human message")
self._append_message(HumanMessage(content=message))
def add_ai_message(
self, message: str, update_if_exist: Optional[bool] = False
) -> None:
"""Add an AI message to the conversation
Args:
message (str): The message content
update_if_exist (bool): Whether to update the existing AI message instead of appending a new one
"""
if not update_if_exist:
self._append_message(AIMessage(content=message))
return
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
if has_message:
self.__update_ai_message(message)
self._update_ai_message(message)
else:
self.messages.append(AIMessage(content=message))
""" """
self._append_message(AIMessage(content=message))
def __update_ai_message(self, new_message: str) -> None:
def _update_ai_message(self, new_message: str) -> None:
"""
stream out message update
Args:
@@ -252,13 +357,11 @@ class OnceConversation:
def add_view_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.messages.append(ViewMessage(content=message))
""" """
self._append_message(ViewMessage(content=message))
def add_system_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.messages.append(SystemMessage(content=message))
"""Add a system message to the store"""
self._append_message(SystemMessage(content=message))
def set_start_time(self, datatime: datetime):
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
@@ -267,23 +370,369 @@ class OnceConversation:
def clear(self) -> None:
"""Remove all messages from the store"""
self.messages.clear()
self.session_id = None
def get_user_conv(self):
for message in self.messages:
def get_latest_user_message(self) -> Optional[HumanMessage]:
"""Get the latest user message"""
for message in self.messages[::-1]:
if isinstance(message, HumanMessage):
return message
return None
def get_system_conv(self):
system_convs = []
def get_system_messages(self) -> List[SystemMessage]:
"""Get the latest user message"""
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
def _to_dict(self) -> Dict:
return _conversation_to_dict(self)
def from_conversation(self, conversation: OnceConversation) -> None:
"""Load the conversation from the storage"""
self.chat_mode = conversation.chat_mode
self.messages = conversation.messages
self.start_date = conversation.start_date
self.chat_order = conversation.chat_order
self.model_name = conversation.model_name
self.param_type = conversation.param_type
self.param_value = conversation.param_value
self.cost = conversation.cost
self.tokens = conversation.tokens
self.user_name = conversation.user_name
self.sys_code = conversation.sys_code
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
"""Get the messages by round index
Args:
round_index (int): The round index
Returns:
List[BaseMessage]: The messages
"""
return list(filter(lambda x: x.round_index == round_index, self.messages))
def get_latest_round(self) -> List[BaseMessage]:
"""Get the latest round messages
Returns:
List[BaseMessage]: The messages
"""
return self.get_messages_by_round(self.chat_order)
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
"""Get the messages with round count
If the round count is 1, the history messages will not be included.
Example:
.. code-block:: python
conversation = OnceConversation()
conversation.start_new_round()
conversation.add_user_message("hello, this is the first round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is the second round")
conversation.add_ai_message("hi")
conversation.end_current_round()
conversation.start_new_round()
conversation.add_user_message("hello, this is the third round")
conversation.add_ai_message("hi")
conversation.end_current_round()
assert len(conversation.get_messages_with_round(1)) == 2
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
assert conversation.get_messages_with_round(1)[1].content == "hi"
assert len(conversation.get_messages_with_round(2)) == 4
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
assert conversation.get_messages_with_round(2)[1].content == "hi"
Args:
round_count (int): The round count
Returns:
List[BaseMessage]: The messages
"""
latest_round_index = self.chat_order
start_round_index = max(1, latest_round_index - round_count + 1)
messages = []
for round_index in range(start_round_index, latest_round_index + 1):
messages.extend(self.get_messages_by_round(round_index))
return messages
def get_model_messages(self) -> List[ModelMessage]:
"""Get the model messages
Model messages only include human, AI, and system messages.
Model messages may include history messages. The order of the messages is the same as their
order in the conversation, and the last message is the latest one.
If you want to handle the messages with your own logic, you can override this method.
Examples:
If you do not need the history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
for message in self.get_latest_round():
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
)
return messages
If you want to include only one round of history messages, you can override this method like this:
.. code-block:: python
def get_model_messages(self) -> List[ModelMessage]:
messages = []
latest_round_index = self.chat_order
round_count = 1
start_round_index = max(1, latest_round_index - round_count + 1)
for round_index in range(start_round_index, latest_round_index + 1):
for message in self.get_messages_by_round(round_index):
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
)
return messages
Returns:
List[ModelMessage]: The model messages
"""
messages = []
for message in self.messages:
if isinstance(message, SystemMessage):
system_convs.append(message)
return system_convs
if message.pass_to_model:
messages.append(
ModelMessage(role=message.type, content=message.content)
)
return messages
def _conversation_to_dict(once: OnceConversation) -> dict:
class ConversationIdentifier(ResourceIdentifier):
"""Conversation identifier"""
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
self.conv_uid = conv_uid
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}:{self.conv_uid}"
def to_dict(self) -> Dict:
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
class MessageIdentifier(ResourceIdentifier):
"""Message identifier"""
identifier_split = "___"
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
self.conv_uid = conv_uid
self.index = index
self.identifier_type = identifier_type
@property
def str_identifier(self) -> str:
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
@staticmethod
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
"""Convert from str identifier
Args:
str_identifier (str): The str identifier
Returns:
MessageIdentifier: The message identifier
"""
parts = str_identifier.split(MessageIdentifier.identifier_split)
if len(parts) != 3:
raise ValueError(f"Invalid str identifier: {str_identifier}")
return MessageIdentifier(parts[1], int(parts[2]))
def to_dict(self) -> Dict:
return {
"conv_uid": self.conv_uid,
"index": self.index,
"identifier_type": self.identifier_type,
}
class MessageStorageItem(StorageItem):
@property
def identifier(self) -> MessageIdentifier:
return self._id
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
self.conv_uid = conv_uid
self.index = index
self.message_detail = message_detail
self._id = MessageIdentifier(conv_uid, index)
def to_dict(self) -> Dict:
return {
"conv_uid": self.conv_uid,
"index": self.index,
"message_detail": self.message_detail,
}
def to_message(self) -> BaseMessage:
"""Convert to message object
Returns:
BaseMessage: The message object
Raises:
ValueError: If the message type is not supported
"""
return _message_from_dict(self.message_detail)
def merge(self, other: "StorageItem") -> None:
"""Merge the other message to self
Args:
other (StorageItem): The other message
"""
if not isinstance(other, MessageStorageItem):
raise ValueError(f"Can not merge {other} to {self}")
self.message_detail = other.message_detail
class StorageConversation(OnceConversation, StorageItem):
"""All the information of a conversation, the current single service in memory,
can expand cache and database support distributed services.
"""
@property
def identifier(self) -> ConversationIdentifier:
return self._id
def to_dict(self) -> Dict:
dict_data = self._to_dict()
messages: Dict = dict_data.pop("messages")
message_ids = []
index = 0
for message in messages:
if "index" in message:
message_idx = message["index"]
else:
message_idx = index
index += 1
message_ids.append(
MessageIdentifier(self.conv_uid, message_idx).str_identifier
)
# Replace message with message ids
dict_data["conv_uid"] = self.conv_uid
dict_data["message_ids"] = message_ids
dict_data["save_message_independent"] = self.save_message_independent
return dict_data
def merge(self, other: "StorageItem") -> None:
"""Merge the other conversation to self
Args:
other (StorageItem): The other conversation
"""
if not isinstance(other, StorageConversation):
raise ValueError(f"Can not merge {other} to {self}")
self.from_conversation(other)
def __init__(
self,
conv_uid: str,
chat_mode: str = None,
user_name: str = None,
sys_code: str = None,
message_ids: List[str] = None,
summary: str = None,
save_message_independent: Optional[bool] = True,
conv_storage: StorageInterface = None,
message_storage: StorageInterface = None,
**kwargs,
):
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
self.conv_uid = conv_uid
self._message_ids = message_ids
self.save_message_independent = save_message_independent
self._id = ConversationIdentifier(conv_uid)
if conv_storage is None:
conv_storage = InMemoryStorage()
if message_storage is None:
message_storage = InMemoryStorage()
self.conv_storage = conv_storage
self.message_storage = message_storage
# Load from storage
self.load_from_storage(self.conv_storage, self.message_storage)
@property
def message_ids(self) -> List[str]:
"""Get the message ids
Returns:
List[str]: The message ids
"""
return self._message_ids if self._message_ids else []
def end_current_round(self) -> None:
"""End the current round of conversation
Save the conversation to the storage after a round of conversation
"""
self.save_to_storage()
def _get_message_items(self) -> List[MessageStorageItem]:
return [
MessageStorageItem(self.conv_uid, message.index, message.to_dict())
for message in self.messages
]
def save_to_storage(self) -> None:
"""Save the conversation to the storage"""
# Save messages first
message_list = self._get_message_items()
self._message_ids = [
message.identifier.str_identifier for message in message_list
]
self.message_storage.save_list(message_list)
# Save conversation
self.conv_storage.save_or_update(self)
def load_from_storage(
self, conv_storage: StorageInterface, message_storage: StorageInterface
) -> None:
"""Load the conversation from the storage
Warning: This will overwrite the current conversation.
Args:
conv_storage (StorageInterface): The storage to load the conversation from
message_storage (StorageInterface): The storage to load the messages from
"""
# Load conversation first
conversation: StorageConversation = conv_storage.load(
self._id, StorageConversation
)
if conversation is None:
return
message_ids = conversation._message_ids or []
# Load messages
message_list = message_storage.load_list(
[
MessageIdentifier.from_str_identifier(message_id)
for message_id in message_ids
],
MessageStorageItem,
)
messages = [message.to_message() for message in message_list]
conversation.messages = messages
self._message_ids = message_ids
self.from_conversation(conversation)
def _conversation_to_dict(once: OnceConversation) -> Dict:
start_str: str = ""
if hasattr(once, "start_date") and once.start_date:
if isinstance(once.start_date, datetime):
@@ -303,6 +752,7 @@ def _conversation_to_dict(once: OnceConversation) -> dict:
"param_value": once.param_value,
"user_name": once.user_name,
"sys_code": once.sys_code,
"summary": once.summary if once.summary else "",
}
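Taken together, the new classes give a conversation a storage-backed lifecycle: start a round, append messages, end the round (which persists the messages first and then the conversation), and later rebuild the same conversation from the same storages. A rough usage sketch against the in-memory backend; it mirrors the behaviour exercised by the tests added in this commit:

from dbgpt.core.interface.message import StorageConversation
from dbgpt.core.interface.storage import InMemoryStorage

conv_storage = InMemoryStorage()
message_storage = InMemoryStorage()

conv = StorageConversation(
    "conv_uid_1",
    chat_mode="chat_normal",
    conv_storage=conv_storage,
    message_storage=message_storage,
)
conv.start_new_round()
conv.add_user_message("hello")
conv.add_ai_message("hi")
conv.end_current_round()  # saves the messages and then the conversation to storage

# A new instance with the same conv_uid reloads the persisted state
restored = StorageConversation(
    "conv_uid_1",
    conv_storage=conv_storage,
    message_storage=message_storage,
)
assert [m.content for m in restored.messages] == ["hello", "hi"]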

View File

@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
)
def __illegal_json_ends(self, s):
def _illegal_json_ends(self, s):
temp_json = s
illegal_json_ends_1 = [", }", ",}"]
illegal_json_ends_2 = ", ]", ",]"
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
temp_json = temp_json.replace(illegal_json_end, " ]")
return temp_json
def __extract_json(self, s):
def _extract_json(self, s):
try:
# Get the dual-mode analysis first and get the maximum result
temp_json_simple = self.__json_interception(s)
temp_json_array = self.__json_interception(s, True)
temp_json_simple = self._json_interception(s)
temp_json_array = self._json_interception(s, True)
if len(temp_json_simple) > len(temp_json_array):
temp_json = temp_json_simple
else:
temp_json = temp_json_array
if not temp_json:
temp_json = self.__json_interception(s)
temp_json = self._json_interception(s)
temp_json = self.__illegal_json_ends(temp_json)
temp_json = self._illegal_json_ends(temp_json)
return temp_json
except Exception as e:
raise ValueError("Failed to find a valid json in LLM response" + temp_json)
def __json_interception(self, s, is_json_array: bool = False):
def _json_interception(self, s, is_json_array: bool = False):
try:
if is_json_array:
i = s.find("[")
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
cleaned_output = cleaned_output.strip()
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
logger.info("illegal json processing:\n" + cleaned_output)
cleaned_output = self.__extract_json(cleaned_output)
cleaned_output = self._extract_json(cleaned_output)
if not cleaned_output or len(cleaned_output) <= 0:
return model_out_text
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
.replace("\\", " ")
.replace("\_", "_")
)
cleaned_output = self.__illegal_json_ends(cleaned_output)
cleaned_output = self._illegal_json_ends(cleaned_output)
return cleaned_output
def parse_view_response(
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
"""Instructions on how the LLM output should be formatted."""
raise NotImplementedError
# @property
# def _type(self) -> str:
# """Return the type key."""
# raise NotImplementedError(
# f"_type property is not implemented in class {self.__class__.__name__}."
# " This is required for serialization."
# )
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
output_parser_dict = super().dict()
output_parser_dict["_type"] = self._type
return output_parser_dict
async def map(self, input_value: ModelOutput) -> Any:
"""Parse the output of an LLM call.

View File

@@ -1,19 +1,34 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, Dict
class Serializable(ABC):
serializer: "Serializer" = None
@abstractmethod
def to_dict(self) -> Dict:
"""Convert the object's state to a dictionary."""
def serialize(self) -> bytes:
"""Convert the object into bytes for storage or transmission.
Returns:
bytes: The byte array after serialization
"""
if self.serializer is None:
raise ValueError(
"Serializer is not set. Please set the serializer before serialization."
)
return self.serializer.serialize(self)
@abstractmethod
def to_dict(self) -> Dict:
"""Convert the object's state to a dictionary."""
def set_serializer(self, serializer: "Serializer") -> None:
"""Set the serializer for current serializable object.
Args:
serializer (Serializer): The serializer to set
"""
self.serializer = serializer
class Serializer(ABC):
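The contract here is deliberately small: a Serializable subclass only has to implement to_dict(), and serialize() delegates to whatever Serializer has been attached via set_serializer(), raising if none is set. A minimal sketch; the KeyValue class is illustrative and not part of this commit:

from typing import Dict

from dbgpt.core.interface.serialization import Serializable
from dbgpt.util.serialization.json_serialization import JsonSerializer


class KeyValue(Serializable):
    def __init__(self, key: str, value: str):
        self.key = key
        self.value = value

    def to_dict(self) -> Dict:
        return {"key": self.key, "value": self.value}


item = KeyValue("greeting", "hello")
item.set_serializer(JsonSerializer())
data = item.serialize()  # bytes produced by the attached serializer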

View File

@@ -0,0 +1,409 @@
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
from abc import ABC, abstractmethod
from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.serialization.json_serialization import JsonSerializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
"""The resource identifier interface for resource identifiers."""
@property
@abstractmethod
def str_identifier(self) -> str:
"""Get the string identifier of the resource.
The string identifier is used to uniquely identify the resource.
Returns:
str: The string identifier of the resource
"""
def __hash__(self) -> int:
"""Return the hash value of the key."""
return hash(self.str_identifier)
def __eq__(self, other: Any) -> bool:
"""Check equality with another key."""
if not isinstance(other, ResourceIdentifier):
return False
return self.str_identifier == other.str_identifier
@PublicAPI(stability="beta")
class StorageItem(Serializable, ABC):
"""The storage item interface for storage items."""
@property
@abstractmethod
def identifier(self) -> ResourceIdentifier:
"""Get the resource identifier of the storage item.
Returns:
ResourceIdentifier: The resource identifier of the storage item
"""
@abstractmethod
def merge(self, other: "StorageItem") -> None:
"""Merge the other storage item into the current storage item.
Args:
other (StorageItem): The other storage item
"""
T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")
class StorageItemAdapter(Generic[T, TDataRepresentation]):
"""The storage item adapter for converting storage items to and from the storage format.
Sometimes, the storage item is not the same as the storage format,
so we need to convert the storage item to the storage format and vice versa.
In database storage, the storage format is database model, but the StorageItem is the user-defined object.
"""
@abstractmethod
def to_storage_format(self, item: T) -> TDataRepresentation:
"""Convert the storage item to the storage format.
Args:
item (T): The storage item
Returns:
TDataRepresentation: The data in the storage format
"""
@abstractmethod
def from_storage_format(self, data: TDataRepresentation) -> T:
"""Convert the storage format to the storage item.
Args:
data (TDataRepresentation): The data in the storage format
Returns:
T: The storage item
"""
@abstractmethod
def get_query_for_identifier(
self,
storage_format: Type[TDataRepresentation],
resource_id: ResourceIdentifier,
**kwargs,
) -> Any:
"""Get the query for the resource identifier.
Args:
storage_format (Type[TDataRepresentation]): The storage format
resource_id (ResourceIdentifier): The resource identifier
kwargs: The additional arguments
Returns:
Any: The query for the resource identifier
"""
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
"""The default storage item adapter for converting storage items to and from the storage format.
The storage item is the same as the storage format, so no conversion is required.
"""
def to_storage_format(self, item: T) -> T:
return item
def from_storage_format(self, data: T) -> T:
return data
def get_query_for_identifier(
self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
) -> bool:
return True
@PublicAPI(stability="beta")
class StorageError(Exception):
"""The base exception class for storage errors."""
def __init__(self, message: str):
super().__init__(message)
@PublicAPI(stability="beta")
class QuerySpec:
"""The query specification for querying data from the storage.
Attributes:
conditions (Dict[str, Any]): The conditions for querying data
limit (int): The maximum number of items to return
offset (int): The offset of the first item to return
"""
def __init__(
self, conditions: Dict[str, Any], limit: int = None, offset: int = 0
) -> None:
self.conditions = conditions
self.limit = limit
self.offset = offset
@PublicAPI(stability="beta")
class StorageInterface(Generic[T, TDataRepresentation], ABC):
"""The storage interface for storing and loading data."""
def __init__(
self,
serializer: Optional[Serializer] = None,
adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
):
self._serializer = serializer or JsonSerializer()
self._storage_item_adapter = adapter or DefaultStorageItemAdapter()
@property
def serializer(self) -> Serializer:
"""Get the serializer of the storage.
Returns:
Serializer: The serializer of the storage
"""
return self._serializer
@property
def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
"""Get the adapter of the storage.
Returns:
StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
"""
return self._storage_item_adapter
@abstractmethod
def save(self, data: T) -> None:
"""Save the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If the data already exists in the storage or data is None
"""
@abstractmethod
def update(self, data: T) -> None:
"""Update the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If data is None
"""
@abstractmethod
def save_or_update(self, data: T) -> None:
"""Save or update the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If data is None
"""
def save_list(self, data: List[T]) -> None:
"""Save the data to the storage.
Args:
data (T): The data to save
Raises:
StorageError: If the data already exists in the storage or data is None
"""
for d in data:
self.save(d)
def save_or_update_list(self, data: List[T]) -> None:
"""Save or update the data to the storage.
Args:
data (T): The data to save
"""
for d in data:
self.save_or_update(d)
@abstractmethod
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
Loading data by resource_id is faster than querying by conditions,
so we suggest using load where possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
Optional[T]: The loaded data
"""
def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
"""Load the data from the storage.
None will be returned if the data does not exist in the storage.
Load data with resource_id will be faster than query data with conditions,
so we suggest to use load if possible.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
cls (Type[T]): The type of the data
Returns:
Optional[T]: The loaded data
"""
result = []
for r in resource_id:
item = self.load(r, cls)
if item is not None:
result.append(item)
return result
@abstractmethod
def delete(self, resource_id: ResourceIdentifier) -> None:
"""Delete the data from the storage.
Args:
resource_id (ResourceIdentifier): The resource identifier of the data
"""
@abstractmethod
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
"""Query data from the storage.
Loading data by resource_id is faster than querying by conditions, so please use load where possible.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
List[T]: The queried data
"""
@abstractmethod
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
"""Count the number of data from the storage.
Args:
spec (QuerySpec): The query specification
cls (Type[T]): The type of the data
Returns:
int: The number of matching items
"""
def paginate_query(
self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
) -> PaginationResult[T]:
"""Paginate the query result.
Args:
page (int): The page number
page_size (int): The number of items per page
cls (Type[T]): The type of the data
spec (Optional[QuerySpec], optional): The query specification. Defaults to None.
Returns:
PaginationResult[T]: The pagination result
"""
if spec is None:
spec = QuerySpec(conditions={})
spec.limit = page_size
spec.offset = (page - 1) * page_size
items = self.query(spec, cls)
total = self.count(spec, cls)
return PaginationResult(
items=items,
total_count=total,
total_pages=(total + page_size - 1) // page_size,
page=page,
page_size=page_size,
)
@PublicAPI(stability="alpha")
class InMemoryStorage(StorageInterface[T, T]):
"""The in-memory storage for storing and loading data."""
def __init__(
self,
serializer: Optional[Serializer] = None,
):
super().__init__(serializer)
self._data = {}  # Key: identifier string, Value: serialized data
def save(self, data: T) -> None:
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
data.set_serializer(self.serializer)
if data.identifier.str_identifier in self._data:
raise StorageError(
f"Data with identifier {data.identifier.str_identifier} already exists"
)
self._data[data.identifier.str_identifier] = data.serialize()
def update(self, data: T) -> None:
if not data:
raise StorageError("Data cannot be None")
if not data.serializer:
data.set_serializer(self.serializer)
self._data[data.identifier.str_identifier] = data.serialize()
def save_or_update(self, data: T) -> None:
self.update(data)
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
serialized_data = self._data.get(resource_id.str_identifier)
if serialized_data is None:
return None
return self.serializer.deserialize(serialized_data, cls)
def delete(self, resource_id: ResourceIdentifier) -> None:
if resource_id.str_identifier in self._data:
del self._data[resource_id.str_identifier]
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
result = []
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
if all(
getattr(data, key) == value for key, value in spec.conditions.items()
):
result.append(data)
# Apply limit and offset
if spec.limit is not None:
result = result[spec.offset : spec.offset + spec.limit]
else:
result = result[spec.offset :]
return result
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
count = 0
for serialized_data in self._data.values():
data = self._serializer.deserialize(serialized_data, cls)
if all(
getattr(data, key) == value for key, value in spec.conditions.items()
):
count += 1
return count
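To make the adapter seam concrete, here is a rough sketch of what a custom StorageItemAdapter could look like for a storage whose native format is a plain dict row. Every name in it (ChatRecord, ChatRecordId, ChatRecordDictAdapter) is illustrative and not part of this commit; a real database-backed storage would use its ORM model as the storage format and build a real query in get_query_for_identifier:

from typing import Any, Dict, Type

from dbgpt.core.interface.storage import (
    ResourceIdentifier,
    StorageItem,
    StorageItemAdapter,
)


class ChatRecordId(ResourceIdentifier):
    def __init__(self, record_id: str):
        self.record_id = record_id

    @property
    def str_identifier(self) -> str:
        return self.record_id

    def to_dict(self) -> Dict:
        return {"record_id": self.record_id}


class ChatRecord(StorageItem):
    """User-facing item, decoupled from how it is stored."""

    def __init__(self, record_id: str, text: str):
        self.record_id = record_id
        self.text = text

    @property
    def identifier(self) -> ResourceIdentifier:
        return ChatRecordId(self.record_id)

    def merge(self, other: "StorageItem") -> None:
        if not isinstance(other, ChatRecord):
            raise ValueError("Can only merge another ChatRecord")
        self.text = other.text

    def to_dict(self) -> Dict:
        return {"record_id": self.record_id, "text": self.text}


class ChatRecordDictAdapter(StorageItemAdapter[ChatRecord, Dict]):
    """Maps a ChatRecord to a plain dict row and back."""

    def to_storage_format(self, item: ChatRecord) -> Dict:
        return {"record_id": item.record_id, "text": item.text}

    def from_storage_format(self, data: Dict) -> ChatRecord:
        return ChatRecord(data["record_id"], data["text"])

    def get_query_for_identifier(
        self, storage_format: Type[Dict], resource_id: ResourceIdentifier, **kwargs
    ) -> Any:
        # A database-backed storage would turn this into a WHERE clause;
        # here it is simply the lookup key.
        return {"record_id": resource_id.str_identifier}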

View File

View File

@@ -0,0 +1,14 @@
import pytest
from dbgpt.core.interface.storage import InMemoryStorage
from dbgpt.util.serialization.json_serialization import JsonSerializer
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def in_memory_storage(serializer):
return InMemoryStorage(serializer)

View File

@@ -0,0 +1,307 @@
import pytest
from dbgpt.core.interface.tests.conftest import in_memory_storage
from dbgpt.core.interface.message import *
@pytest.fixture
def basic_conversation():
return OnceConversation(chat_mode="chat_normal", user_name="user1", sys_code="sys1")
@pytest.fixture
def human_message():
return HumanMessage(content="Hello")
@pytest.fixture
def ai_message():
return AIMessage(content="Hi there")
@pytest.fixture
def system_message():
return SystemMessage(content="System update")
@pytest.fixture
def view_message():
return ViewMessage(content="View this")
@pytest.fixture
def conversation_identifier():
return ConversationIdentifier("conv1")
@pytest.fixture
def message_identifier():
return MessageIdentifier("conv1", 1)
@pytest.fixture
def message_storage_item():
message = HumanMessage(content="Hello", index=1)
message_detail = message.to_dict()
return MessageStorageItem("conv1", 1, message_detail)
@pytest.fixture
def storage_conversation():
return StorageConversation("conv1", chat_mode="chat_normal", user_name="user1")
@pytest.fixture
def conversation_with_messages():
conv = OnceConversation(chat_mode="chat_normal", user_name="user1")
conv.start_new_round()
conv.add_user_message("Hello")
conv.add_ai_message("Hi")
conv.end_current_round()
conv.start_new_round()
conv.add_user_message("How are you?")
conv.add_ai_message("I'm good, thanks")
conv.end_current_round()
return conv
def test_init(basic_conversation):
assert basic_conversation.chat_mode == "chat_normal"
assert basic_conversation.user_name == "user1"
assert basic_conversation.sys_code == "sys1"
assert basic_conversation.messages == []
assert basic_conversation.start_date == ""
assert basic_conversation.chat_order == 0
assert basic_conversation.model_name == ""
assert basic_conversation.param_type == ""
assert basic_conversation.param_value == ""
assert basic_conversation.cost == 0
assert basic_conversation.tokens == 0
assert basic_conversation._message_index == 0
def test_add_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], HumanMessage)
def test_add_ai_message(basic_conversation, ai_message):
basic_conversation.add_ai_message(ai_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], AIMessage)
def test_add_system_message(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], SystemMessage)
def test_add_view_message(basic_conversation, view_message):
basic_conversation.add_view_message(view_message.content)
assert len(basic_conversation.messages) == 1
assert isinstance(basic_conversation.messages[0], ViewMessage)
def test_set_start_time(basic_conversation):
now = datetime.now()
basic_conversation.set_start_time(now)
assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S")
def test_clear_messages(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
basic_conversation.clear()
assert len(basic_conversation.messages) == 0
def test_get_latest_user_message(basic_conversation, human_message):
basic_conversation.add_user_message(human_message.content)
latest_message = basic_conversation.get_latest_user_message()
assert latest_message == human_message
def test_get_system_messages(basic_conversation, system_message):
basic_conversation.add_system_message(system_message.content)
system_messages = basic_conversation.get_system_messages()
assert len(system_messages) == 1
assert system_messages[0] == system_message
def test_from_conversation(basic_conversation):
new_conversation = OnceConversation(chat_mode="chat_advanced", user_name="user2")
basic_conversation.from_conversation(new_conversation)
assert basic_conversation.chat_mode == "chat_advanced"
assert basic_conversation.user_name == "user2"
def test_get_messages_by_round(conversation_with_messages):
# Test first round
round1_messages = conversation_with_messages.get_messages_by_round(1)
assert len(round1_messages) == 2
assert round1_messages[0].content == "Hello"
assert round1_messages[1].content == "Hi"
# Test not existing round
no_messages = conversation_with_messages.get_messages_by_round(3)
assert len(no_messages) == 0
def test_get_latest_round(conversation_with_messages):
latest_round_messages = conversation_with_messages.get_latest_round()
assert len(latest_round_messages) == 2
assert latest_round_messages[0].content == "How are you?"
assert latest_round_messages[1].content == "I'm good, thanks"
def test_get_messages_with_round(conversation_with_messages):
# Test last round
last_round_messages = conversation_with_messages.get_messages_with_round(1)
assert len(last_round_messages) == 2
assert last_round_messages[0].content == "How are you?"
assert last_round_messages[1].content == "I'm good, thanks"
# Test last two rounds
last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2)
assert len(last_two_rounds_messages) == 4
assert last_two_rounds_messages[0].content == "Hello"
assert last_two_rounds_messages[1].content == "Hi"
def test_get_model_messages(conversation_with_messages):
model_messages = conversation_with_messages.get_model_messages()
assert len(model_messages) == 4
assert all(isinstance(msg, ModelMessage) for msg in model_messages)
assert model_messages[0].content == "Hello"
assert model_messages[1].content == "Hi"
assert model_messages[2].content == "How are you?"
assert model_messages[3].content == "I'm good, thanks"
def test_conversation_identifier(conversation_identifier):
assert conversation_identifier.conv_uid == "conv1"
assert conversation_identifier.identifier_type == "conversation"
assert conversation_identifier.str_identifier == "conversation:conv1"
assert conversation_identifier.to_dict() == {
"conv_uid": "conv1",
"identifier_type": "conversation",
}
def test_message_identifier(message_identifier):
assert message_identifier.conv_uid == "conv1"
assert message_identifier.index == 1
assert message_identifier.identifier_type == "message"
assert message_identifier.str_identifier == "message___conv1___1"
assert message_identifier.to_dict() == {
"conv_uid": "conv1",
"index": 1,
"identifier_type": "message",
}
def test_message_storage_item(message_storage_item):
assert message_storage_item.conv_uid == "conv1"
assert message_storage_item.index == 1
assert message_storage_item.message_detail == {
"type": "human",
"data": {
"content": "Hello",
"index": 1,
"round_index": 0,
"additional_kwargs": {},
"example": False,
},
"index": 1,
"round_index": 0,
}
assert isinstance(message_storage_item.identifier, MessageIdentifier)
assert message_storage_item.to_dict() == {
"conv_uid": "conv1",
"index": 1,
"message_detail": {
"type": "human",
"index": 1,
"data": {
"content": "Hello",
"index": 1,
"round_index": 0,
"additional_kwargs": {},
"example": False,
},
"round_index": 0,
},
}
assert isinstance(message_storage_item.to_message(), BaseMessage)
def test_storage_conversation_init(storage_conversation):
assert storage_conversation.conv_uid == "conv1"
assert storage_conversation.chat_mode == "chat_normal"
assert storage_conversation.user_name == "user1"
def test_storage_conversation_add_user_message(storage_conversation):
storage_conversation.add_user_message("Hi")
assert len(storage_conversation.messages) == 1
assert isinstance(storage_conversation.messages[0], HumanMessage)
def test_storage_conversation_add_ai_message(storage_conversation):
storage_conversation.add_ai_message("Hello")
assert len(storage_conversation.messages) == 1
assert isinstance(storage_conversation.messages[0], AIMessage)
def test_save_to_storage(storage_conversation, in_memory_storage):
# Set storage
storage_conversation.conv_storage = in_memory_storage
storage_conversation.message_storage = in_memory_storage
# Add messages
storage_conversation.add_user_message("User message")
storage_conversation.add_ai_message("AI response")
# Save to storage
storage_conversation.save_to_storage()
# Create a new StorageConversation instance to load the data
saved_conversation = StorageConversation(
storage_conversation.conv_uid,
conv_storage=in_memory_storage,
message_storage=in_memory_storage,
)
assert saved_conversation.conv_uid == storage_conversation.conv_uid
assert len(saved_conversation.messages) == 2
assert isinstance(saved_conversation.messages[0], HumanMessage)
assert isinstance(saved_conversation.messages[1], AIMessage)
def test_load_from_storage(storage_conversation, in_memory_storage):
# Set storage
storage_conversation.conv_storage = in_memory_storage
storage_conversation.message_storage = in_memory_storage
# Add messages and save to storage
storage_conversation.add_user_message("User message")
storage_conversation.add_ai_message("AI response")
storage_conversation.save_to_storage()
# Create a new StorageConversation instance to load the data
new_conversation = StorageConversation(
"conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
)
# Check if the data is loaded correctly
assert new_conversation.conv_uid == storage_conversation.conv_uid
assert len(new_conversation.messages) == 2
assert new_conversation.messages[0].content == "User message"
assert new_conversation.messages[1].content == "AI response"
assert isinstance(new_conversation.messages[0], HumanMessage)
assert isinstance(new_conversation.messages[1], AIMessage)

View File

@@ -0,0 +1,129 @@
import pytest
from typing import Dict, Type, Union
from dbgpt.core.interface.storage import (
ResourceIdentifier,
StorageError,
QuerySpec,
InMemoryStorage,
StorageItem,
)
from dbgpt.util.serialization.json_serialization import JsonSerializer
class MockResourceIdentifier(ResourceIdentifier):
def __init__(self, identifier: str):
self._identifier = identifier
@property
def str_identifier(self) -> str:
return self._identifier
def to_dict(self) -> Dict:
return {"identifier": self._identifier}
class MockStorageItem(StorageItem):
def merge(self, other: "StorageItem") -> None:
if not isinstance(other, MockStorageItem):
raise ValueError("other must be a MockStorageItem")
self.data = other.data
def __init__(self, identifier: Union[str, MockResourceIdentifier], data):
self._identifier_str = (
identifier if isinstance(identifier, str) else identifier.str_identifier
)
self.data = data
def to_dict(self) -> Dict:
return {"identifier": self._identifier_str, "data": self.data}
@property
def identifier(self) -> ResourceIdentifier:
return MockResourceIdentifier(self._identifier_str)
@pytest.fixture
def serializer():
return JsonSerializer()
@pytest.fixture
def in_memory_storage(serializer):
return InMemoryStorage(serializer)
def test_save_and_load(in_memory_storage):
resource_id = MockResourceIdentifier("1")
item = MockStorageItem(resource_id, "test_data")
in_memory_storage.save(item)
loaded_item = in_memory_storage.load(resource_id, MockStorageItem)
assert loaded_item.data == "test_data"
def test_duplicate_save(in_memory_storage):
item = MockStorageItem("1", "test_data")
in_memory_storage.save(item)
# Should raise StorageError when saving the same data
with pytest.raises(StorageError):
in_memory_storage.save(item)
def test_delete(in_memory_storage):
resource_id = MockResourceIdentifier("1")
item = MockStorageItem(resource_id, "test_data")
in_memory_storage.save(item)
in_memory_storage.delete(resource_id)
# Storage should not contain the data after deletion
assert in_memory_storage.load(resource_id, MockStorageItem) is None
def test_query(in_memory_storage):
resource_id1 = MockResourceIdentifier("1")
item1 = MockStorageItem(resource_id1, "test_data1")
resource_id2 = MockResourceIdentifier("2")
item2 = MockStorageItem(resource_id2, "test_data2")
in_memory_storage.save(item1)
in_memory_storage.save(item2)
query_spec = QuerySpec(conditions={"data": "test_data1"})
results = in_memory_storage.query(query_spec, MockStorageItem)
assert len(results) == 1
assert results[0].data == "test_data1"
def test_count(in_memory_storage):
item1 = MockStorageItem("1", "test_data1")
item2 = MockStorageItem("2", "test_data2")
in_memory_storage.save(item1)
in_memory_storage.save(item2)
query_spec = QuerySpec(conditions={})
count = in_memory_storage.count(query_spec, MockStorageItem)
assert count == 2
def test_paginate_query(in_memory_storage):
for i in range(10):
resource_id = MockResourceIdentifier(str(i))
item = MockStorageItem(resource_id, f"test_data{i}")
in_memory_storage.save(item)
page_size = 3
query_spec = QuerySpec(conditions={})
page_result = in_memory_storage.paginate_query(
2, page_size, MockStorageItem, query_spec
)
assert len(page_result.items) == page_size
assert page_result.total_count == 10
assert page_result.total_pages == 4
assert page_result.page == 2
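As a cross-check of these assertions: paginate_query computes total_pages as (total + page_size - 1) // page_size and the offset as (page - 1) * page_size, so 10 saved items with page_size 3 give (10 + 3 - 1) // 3 = 4 pages, and page 2 queries with offset (2 - 1) * 3 = 3, i.e. the fourth through sixth items saved.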