mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 01:04:43 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
@@ -9,6 +9,10 @@ from dbgpt.core.interface.message import (
|
||||
ModelMessage,
|
||||
ModelMessageRoleType,
|
||||
OnceConversation,
|
||||
StorageConversation,
|
||||
MessageStorageItem,
|
||||
ConversationIdentifier,
|
||||
MessageIdentifier,
|
||||
)
|
||||
from dbgpt.core.interface.prompt import PromptTemplate, PromptTemplateOperator
|
||||
from dbgpt.core.interface.output_parser import BaseOutputParser, SQLOutputParser
|
||||
@@ -20,6 +24,16 @@ from dbgpt.core.interface.cache import (
|
||||
CachePolicy,
|
||||
CacheConfig,
|
||||
)
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageItem,
|
||||
StorageItemAdapter,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
DefaultStorageItemAdapter,
|
||||
QuerySpec,
|
||||
StorageError,
|
||||
)
|
||||
|
||||
__ALL__ = [
|
||||
"ModelInferenceMetrics",
|
||||
@@ -30,6 +44,10 @@ __ALL__ = [
|
||||
"ModelMessage",
|
||||
"ModelMessageRoleType",
|
||||
"OnceConversation",
|
||||
"StorageConversation",
|
||||
"MessageStorageItem",
|
||||
"ConversationIdentifier",
|
||||
"MessageIdentifier",
|
||||
"PromptTemplate",
|
||||
"PromptTemplateOperator",
|
||||
"BaseOutputParser",
|
||||
@@ -41,4 +59,12 @@ __ALL__ = [
|
||||
"CacheClient",
|
||||
"CachePolicy",
|
||||
"CacheConfig",
|
||||
"ResourceIdentifier",
|
||||
"StorageItem",
|
||||
"StorageItemAdapter",
|
||||
"StorageInterface",
|
||||
"InMemoryStorage",
|
||||
"DefaultStorageItemAdapter",
|
||||
"QuerySpec",
|
||||
"StorageError",
|
||||
]
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
import threading
|
||||
import asyncio
|
||||
from ..dag import DAG, DAGContext
|
||||
from ..base import DAG, DAGVar
|
||||
|
||||
|
||||
def test_dag_context_sync():
|
||||
@@ -9,18 +9,18 @@ def test_dag_context_sync():
|
||||
dag2 = DAG("dag2")
|
||||
|
||||
with dag1:
|
||||
assert DAGContext.get_current_dag() == dag1
|
||||
assert DAGVar.get_current_dag() == dag1
|
||||
with dag2:
|
||||
assert DAGContext.get_current_dag() == dag2
|
||||
assert DAGContext.get_current_dag() == dag1
|
||||
assert DAGContext.get_current_dag() is None
|
||||
assert DAGVar.get_current_dag() == dag2
|
||||
assert DAGVar.get_current_dag() == dag1
|
||||
assert DAGVar.get_current_dag() is None
|
||||
|
||||
|
||||
def test_dag_context_threading():
|
||||
def thread_function(dag):
|
||||
DAGContext.enter_dag(dag)
|
||||
assert DAGContext.get_current_dag() == dag
|
||||
DAGContext.exit_dag()
|
||||
DAGVar.enter_dag(dag)
|
||||
assert DAGVar.get_current_dag() == dag
|
||||
DAGVar.exit_dag()
|
||||
|
||||
dag1 = DAG("dag1")
|
||||
dag2 = DAG("dag2")
|
||||
@@ -33,19 +33,19 @@ def test_dag_context_threading():
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
assert DAGContext.get_current_dag() is None
|
||||
assert DAGVar.get_current_dag() is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dag_context_async():
|
||||
async def async_function(dag):
|
||||
DAGContext.enter_dag(dag)
|
||||
assert DAGContext.get_current_dag() == dag
|
||||
DAGContext.exit_dag()
|
||||
DAGVar.enter_dag(dag)
|
||||
assert DAGVar.get_current_dag() == dag
|
||||
DAGVar.exit_dag()
|
||||
|
||||
dag1 = DAG("dag1")
|
||||
dag2 = DAG("dag2")
|
||||
|
||||
await asyncio.gather(async_function(dag1), async_function(dag2))
|
||||
|
||||
assert DAGContext.get_current_dag() is None
|
||||
assert DAGVar.get_current_dag() is None
|
||||
|
@@ -1,16 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from dbgpt._private.pydantic import BaseModel, Field
|
||||
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageItem,
|
||||
StorageInterface,
|
||||
InMemoryStorage,
|
||||
)
|
||||
|
||||
|
||||
class BaseMessage(BaseModel, ABC):
|
||||
"""Message object."""
|
||||
|
||||
content: str
|
||||
index: int = 0
|
||||
round_index: int = 0
|
||||
"""The round index of the message in the conversation"""
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
@property
|
||||
@@ -18,6 +28,24 @@ class BaseMessage(BaseModel, ABC):
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model"""
|
||||
return True
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert to dict
|
||||
|
||||
Returns:
|
||||
Dict: The dict object
|
||||
"""
|
||||
return {
|
||||
"type": self.type,
|
||||
"data": self.dict(),
|
||||
"index": self.index,
|
||||
"round_index": self.round_index,
|
||||
}
|
||||
|
||||
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
@@ -51,6 +79,14 @@ class ViewMessage(BaseMessage):
|
||||
"""Type of the message, used for serialization."""
|
||||
return "view"
|
||||
|
||||
@property
|
||||
def pass_to_model(self) -> bool:
|
||||
"""Whether the message will be passed to the model
|
||||
|
||||
The view message will not be passed to the model
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
@@ -141,15 +177,15 @@ class ModelMessage(BaseModel):
|
||||
return ModelMessage(role=ModelMessageRoleType.HUMAN, content=content)
|
||||
|
||||
|
||||
def _message_to_dict(message: BaseMessage) -> dict:
|
||||
return {"type": message.type, "data": message.dict()}
|
||||
def _message_to_dict(message: BaseMessage) -> Dict:
|
||||
return message.to_dict()
|
||||
|
||||
|
||||
def _messages_to_dict(messages: List[BaseMessage]) -> List[dict]:
|
||||
def _messages_to_dict(messages: List[BaseMessage]) -> List[Dict]:
|
||||
return [_message_to_dict(m) for m in messages]
|
||||
|
||||
|
||||
def _message_from_dict(message: dict) -> BaseMessage:
|
||||
def _message_from_dict(message: Dict) -> BaseMessage:
|
||||
_type = message["type"]
|
||||
if _type == "human":
|
||||
return HumanMessage(**message["data"])
|
||||
@@ -163,7 +199,7 @@ def _message_from_dict(message: dict) -> BaseMessage:
|
||||
raise ValueError(f"Got unexpected type: {_type}")
|
||||
|
||||
|
||||
def _messages_from_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
def _messages_from_dict(messages: List[Dict]) -> List[BaseMessage]:
|
||||
return [_message_from_dict(m) for m in messages]
|
||||
|
||||
|
||||
@@ -193,50 +229,119 @@ def _parse_model_messages(
|
||||
history_messages.append([])
|
||||
if messages[-1].role != "human":
|
||||
raise ValueError("Hi! What do you want to talk about?")
|
||||
# Keep message pair of [user message, assistant message]
|
||||
# Keep message a pair of [user message, assistant message]
|
||||
history_messages = list(filter(lambda x: len(x) == 2, history_messages))
|
||||
user_prompt = messages[-1].content
|
||||
return user_prompt, system_messages, history_messages
|
||||
|
||||
|
||||
class OnceConversation:
|
||||
"""
|
||||
All the information of a conversation, the current single service in memory, can expand cache and database support distributed services
|
||||
"""All the information of a conversation, the current single service in memory,
|
||||
can expand cache and database support distributed services.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, chat_mode, user_name: str = None, sys_code: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
chat_mode: str,
|
||||
user_name: str = None,
|
||||
sys_code: str = None,
|
||||
summary: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.chat_mode: str = chat_mode
|
||||
self.messages: List[BaseMessage] = []
|
||||
self.start_date: str = ""
|
||||
self.chat_order: int = 0
|
||||
self.model_name: str = ""
|
||||
self.param_type: str = ""
|
||||
self.param_value: str = ""
|
||||
self.cost: int = 0
|
||||
self.tokens: int = 0
|
||||
self.user_name: str = user_name
|
||||
self.sys_code: str = sys_code
|
||||
self.summary: str = summary
|
||||
|
||||
def add_user_message(self, message: str) -> None:
|
||||
"""Add a user message to the store"""
|
||||
has_message = any(
|
||||
isinstance(instance, HumanMessage) for instance in self.messages
|
||||
)
|
||||
if has_message:
|
||||
raise ValueError("Already Have Human message")
|
||||
self.messages.append(HumanMessage(content=message))
|
||||
self.messages: List[BaseMessage] = kwargs.get("messages", [])
|
||||
self.start_date: str = kwargs.get("start_date", "")
|
||||
# After each complete round of dialogue, the current value will be increased by 1
|
||||
self.chat_order: int = int(kwargs.get("chat_order", 0))
|
||||
self.model_name: str = kwargs.get("model_name", "")
|
||||
self.param_type: str = kwargs.get("param_type", "")
|
||||
self.param_value: str = kwargs.get("param_value", "")
|
||||
self.cost: int = int(kwargs.get("cost", 0))
|
||||
self.tokens: int = int(kwargs.get("tokens", 0))
|
||||
self._message_index: int = int(kwargs.get("message_index", 0))
|
||||
|
||||
def add_ai_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
def _append_message(self, message: BaseMessage) -> None:
|
||||
index = self._message_index
|
||||
self._message_index += 1
|
||||
message.index = index
|
||||
message.round_index = self.chat_order
|
||||
self.messages.append(message)
|
||||
|
||||
def start_new_round(self) -> None:
|
||||
"""Start a new round of conversation
|
||||
|
||||
Example:
|
||||
>>> conversation = OnceConversation()
|
||||
>>> # The chat order will be 0, then we start a new round of conversation
|
||||
>>> assert conversation.chat_order == 0
|
||||
>>> conversation.start_new_round()
|
||||
>>> # Now the chat order will be 1
|
||||
>>> assert conversation.chat_order == 1
|
||||
>>> conversation.add_user_message("hello")
|
||||
>>> conversation.add_ai_message("hi")
|
||||
>>> conversation.end_current_round()
|
||||
>>> # Now the chat order will be 1, then we start a new round of conversation
|
||||
>>> conversation.start_new_round()
|
||||
>>> # Now the chat order will be 2
|
||||
>>> assert conversation.chat_order == 2
|
||||
>>> conversation.add_user_message("hello")
|
||||
>>> conversation.add_ai_message("hi")
|
||||
>>> conversation.end_current_round()
|
||||
>>> assert conversation.chat_order == 2
|
||||
"""
|
||||
self.chat_order += 1
|
||||
|
||||
def end_current_round(self) -> None:
|
||||
"""End the current round of conversation
|
||||
|
||||
We do noting here, just for the interface
|
||||
"""
|
||||
pass
|
||||
|
||||
def add_user_message(
|
||||
self, message: str, check_duplicate_type: Optional[bool] = False
|
||||
) -> None:
|
||||
"""Add a user message to the conversation
|
||||
|
||||
Args:
|
||||
message (str): The message content
|
||||
check_duplicate_type (bool): Whether to check the duplicate message type
|
||||
|
||||
Raises:
|
||||
ValueError: If the message is duplicate and check_duplicate_type is True
|
||||
"""
|
||||
if check_duplicate_type:
|
||||
has_message = any(
|
||||
isinstance(instance, HumanMessage) for instance in self.messages
|
||||
)
|
||||
if has_message:
|
||||
raise ValueError("Already Have Human message")
|
||||
self._append_message(HumanMessage(content=message))
|
||||
|
||||
def add_ai_message(
|
||||
self, message: str, update_if_exist: Optional[bool] = False
|
||||
) -> None:
|
||||
"""Add an AI message to the conversation
|
||||
|
||||
Args:
|
||||
message (str): The message content
|
||||
update_if_exist (bool): Whether to update the message if the message type is duplicate
|
||||
"""
|
||||
if not update_if_exist:
|
||||
self._append_message(AIMessage(content=message))
|
||||
return
|
||||
has_message = any(isinstance(instance, AIMessage) for instance in self.messages)
|
||||
if has_message:
|
||||
self.__update_ai_message(message)
|
||||
self._update_ai_message(message)
|
||||
else:
|
||||
self.messages.append(AIMessage(content=message))
|
||||
""" """
|
||||
self._append_message(AIMessage(content=message))
|
||||
|
||||
def __update_ai_message(self, new_message: str) -> None:
|
||||
def _update_ai_message(self, new_message: str) -> None:
|
||||
"""
|
||||
stream out message update
|
||||
Args:
|
||||
@@ -252,13 +357,11 @@ class OnceConversation:
|
||||
|
||||
def add_view_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
|
||||
self.messages.append(ViewMessage(content=message))
|
||||
""" """
|
||||
self._append_message(ViewMessage(content=message))
|
||||
|
||||
def add_system_message(self, message: str) -> None:
|
||||
"""Add an AI message to the store"""
|
||||
self.messages.append(SystemMessage(content=message))
|
||||
"""Add a system message to the store"""
|
||||
self._append_message(SystemMessage(content=message))
|
||||
|
||||
def set_start_time(self, datatime: datetime):
|
||||
dt_str = datatime.strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -267,23 +370,369 @@ class OnceConversation:
|
||||
def clear(self) -> None:
|
||||
"""Remove all messages from the store"""
|
||||
self.messages.clear()
|
||||
self.session_id = None
|
||||
|
||||
def get_user_conv(self):
|
||||
for message in self.messages:
|
||||
def get_latest_user_message(self) -> Optional[HumanMessage]:
|
||||
"""Get the latest user message"""
|
||||
for message in self.messages[::-1]:
|
||||
if isinstance(message, HumanMessage):
|
||||
return message
|
||||
return None
|
||||
|
||||
def get_system_conv(self):
|
||||
system_convs = []
|
||||
def get_system_messages(self) -> List[SystemMessage]:
|
||||
"""Get the latest user message"""
|
||||
return list(filter(lambda x: isinstance(x, SystemMessage), self.messages))
|
||||
|
||||
def _to_dict(self) -> Dict:
|
||||
return _conversation_to_dict(self)
|
||||
|
||||
def from_conversation(self, conversation: OnceConversation) -> None:
|
||||
"""Load the conversation from the storage"""
|
||||
self.chat_mode = conversation.chat_mode
|
||||
self.messages = conversation.messages
|
||||
self.start_date = conversation.start_date
|
||||
self.chat_order = conversation.chat_order
|
||||
self.model_name = conversation.model_name
|
||||
self.param_type = conversation.param_type
|
||||
self.param_value = conversation.param_value
|
||||
self.cost = conversation.cost
|
||||
self.tokens = conversation.tokens
|
||||
self.user_name = conversation.user_name
|
||||
self.sys_code = conversation.sys_code
|
||||
|
||||
def get_messages_by_round(self, round_index: int) -> List[BaseMessage]:
|
||||
"""Get the messages by round index
|
||||
|
||||
Args:
|
||||
round_index (int): The round index
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The messages
|
||||
"""
|
||||
return list(filter(lambda x: x.round_index == round_index, self.messages))
|
||||
|
||||
def get_latest_round(self) -> List[BaseMessage]:
|
||||
"""Get the latest round messages
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The messages
|
||||
"""
|
||||
return self.get_messages_by_round(self.chat_order)
|
||||
|
||||
def get_messages_with_round(self, round_count: int) -> List[BaseMessage]:
|
||||
"""Get the messages with round count
|
||||
|
||||
If the round count is 1, the history messages will not be included.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
conversation = OnceConversation()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the first round")
|
||||
conversation.add_ai_message("hi")
|
||||
conversation.end_current_round()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the second round")
|
||||
conversation.add_ai_message("hi")
|
||||
conversation.end_current_round()
|
||||
conversation.start_new_round()
|
||||
conversation.add_user_message("hello, this is the third round")
|
||||
conversation.add_ai_message("hi")
|
||||
conversation.end_current_round()
|
||||
|
||||
assert len(conversation.get_messages_with_round(1)) == 2
|
||||
assert conversation.get_messages_with_round(1)[0].content == "hello, this is the third round"
|
||||
assert conversation.get_messages_with_round(1)[1].content == "hi"
|
||||
|
||||
assert len(conversation.get_messages_with_round(2)) == 4
|
||||
assert conversation.get_messages_with_round(2)[0].content == "hello, this is the second round"
|
||||
assert conversation.get_messages_with_round(2)[1].content == "hi"
|
||||
|
||||
Args:
|
||||
round_count (int): The round count
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: The messages
|
||||
"""
|
||||
latest_round_index = self.chat_order
|
||||
start_round_index = max(1, latest_round_index - round_count + 1)
|
||||
messages = []
|
||||
for round_index in range(start_round_index, latest_round_index + 1):
|
||||
messages.extend(self.get_messages_by_round(round_index))
|
||||
return messages
|
||||
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
"""Get the model messages
|
||||
|
||||
Model messages just include human, ai and system messages.
|
||||
Model messages maybe include the history messages, The order of the messages is the same as the order of
|
||||
the messages in the conversation, the last message is the latest message.
|
||||
|
||||
If you want to hand the message with your own logic, you can override this method.
|
||||
|
||||
Examples:
|
||||
If you not need the history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
for message in self.get_latest_round():
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
)
|
||||
return messages
|
||||
|
||||
If you want to add the one round history messages, you can override this method like this:
|
||||
.. code-block:: python
|
||||
def get_model_messages(self) -> List[ModelMessage]:
|
||||
messages = []
|
||||
latest_round_index = self.chat_order
|
||||
round_count = 1
|
||||
start_round_index = max(1, latest_round_index - round_count + 1)
|
||||
for round_index in range(start_round_index, latest_round_index + 1):
|
||||
for message in self.get_messages_by_round(round_index):
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
)
|
||||
return messages
|
||||
|
||||
Returns:
|
||||
List[ModelMessage]: The model messages
|
||||
"""
|
||||
messages = []
|
||||
for message in self.messages:
|
||||
if isinstance(message, SystemMessage):
|
||||
system_convs.append(message)
|
||||
return system_convs
|
||||
if message.pass_to_model:
|
||||
messages.append(
|
||||
ModelMessage(role=message.type, content=message.content)
|
||||
)
|
||||
return messages
|
||||
|
||||
|
||||
def _conversation_to_dict(once: OnceConversation) -> dict:
|
||||
class ConversationIdentifier(ResourceIdentifier):
|
||||
"""Conversation identifier"""
|
||||
|
||||
def __init__(self, conv_uid: str, identifier_type: str = "conversation"):
|
||||
self.conv_uid = conv_uid
|
||||
self.identifier_type = identifier_type
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
return f"{self.identifier_type}:{self.conv_uid}"
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {"conv_uid": self.conv_uid, "identifier_type": self.identifier_type}
|
||||
|
||||
|
||||
class MessageIdentifier(ResourceIdentifier):
|
||||
"""Message identifier"""
|
||||
|
||||
identifier_split = "___"
|
||||
|
||||
def __init__(self, conv_uid: str, index: int, identifier_type: str = "message"):
|
||||
self.conv_uid = conv_uid
|
||||
self.index = index
|
||||
self.identifier_type = identifier_type
|
||||
|
||||
@property
|
||||
def str_identifier(self) -> str:
|
||||
return f"{self.identifier_type}{self.identifier_split}{self.conv_uid}{self.identifier_split}{self.index}"
|
||||
|
||||
@staticmethod
|
||||
def from_str_identifier(str_identifier: str) -> MessageIdentifier:
|
||||
"""Convert from str identifier
|
||||
|
||||
Args:
|
||||
str_identifier (str): The str identifier
|
||||
|
||||
Returns:
|
||||
MessageIdentifier: The message identifier
|
||||
"""
|
||||
parts = str_identifier.split(MessageIdentifier.identifier_split)
|
||||
if len(parts) != 3:
|
||||
raise ValueError(f"Invalid str identifier: {str_identifier}")
|
||||
return MessageIdentifier(parts[1], int(parts[2]))
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"conv_uid": self.conv_uid,
|
||||
"index": self.index,
|
||||
"identifier_type": self.identifier_type,
|
||||
}
|
||||
|
||||
|
||||
class MessageStorageItem(StorageItem):
|
||||
@property
|
||||
def identifier(self) -> MessageIdentifier:
|
||||
return self._id
|
||||
|
||||
def __init__(self, conv_uid: str, index: int, message_detail: Dict):
|
||||
self.conv_uid = conv_uid
|
||||
self.index = index
|
||||
self.message_detail = message_detail
|
||||
self._id = MessageIdentifier(conv_uid, index)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"conv_uid": self.conv_uid,
|
||||
"index": self.index,
|
||||
"message_detail": self.message_detail,
|
||||
}
|
||||
|
||||
def to_message(self) -> BaseMessage:
|
||||
"""Convert to message object
|
||||
Returns:
|
||||
BaseMessage: The message object
|
||||
|
||||
Raises:
|
||||
ValueError: If the message type is not supported
|
||||
"""
|
||||
return _message_from_dict(self.message_detail)
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other message to self
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other message
|
||||
"""
|
||||
if not isinstance(other, MessageStorageItem):
|
||||
raise ValueError(f"Can not merge {other} to {self}")
|
||||
self.message_detail = other.message_detail
|
||||
|
||||
|
||||
class StorageConversation(OnceConversation, StorageItem):
|
||||
"""All the information of a conversation, the current single service in memory,
|
||||
can expand cache and database support distributed services.
|
||||
|
||||
"""
|
||||
|
||||
@property
|
||||
def identifier(self) -> ConversationIdentifier:
|
||||
return self._id
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
dict_data = self._to_dict()
|
||||
messages: Dict = dict_data.pop("messages")
|
||||
message_ids = []
|
||||
index = 0
|
||||
for message in messages:
|
||||
if "index" in message:
|
||||
message_idx = message["index"]
|
||||
else:
|
||||
message_idx = index
|
||||
index += 1
|
||||
message_ids.append(
|
||||
MessageIdentifier(self.conv_uid, message_idx).str_identifier
|
||||
)
|
||||
# Replace message with message ids
|
||||
dict_data["conv_uid"] = self.conv_uid
|
||||
dict_data["message_ids"] = message_ids
|
||||
dict_data["save_message_independent"] = self.save_message_independent
|
||||
return dict_data
|
||||
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other conversation to self
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other conversation
|
||||
"""
|
||||
if not isinstance(other, StorageConversation):
|
||||
raise ValueError(f"Can not merge {other} to {self}")
|
||||
self.from_conversation(other)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conv_uid: str,
|
||||
chat_mode: str = None,
|
||||
user_name: str = None,
|
||||
sys_code: str = None,
|
||||
message_ids: List[str] = None,
|
||||
summary: str = None,
|
||||
save_message_independent: Optional[bool] = True,
|
||||
conv_storage: StorageInterface = None,
|
||||
message_storage: StorageInterface = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
|
||||
self.conv_uid = conv_uid
|
||||
self._message_ids = message_ids
|
||||
self.save_message_independent = save_message_independent
|
||||
self._id = ConversationIdentifier(conv_uid)
|
||||
if conv_storage is None:
|
||||
conv_storage = InMemoryStorage()
|
||||
if message_storage is None:
|
||||
message_storage = InMemoryStorage()
|
||||
self.conv_storage = conv_storage
|
||||
self.message_storage = message_storage
|
||||
# Load from storage
|
||||
self.load_from_storage(self.conv_storage, self.message_storage)
|
||||
|
||||
@property
|
||||
def message_ids(self) -> List[str]:
|
||||
"""Get the message ids
|
||||
|
||||
Returns:
|
||||
List[str]: The message ids
|
||||
"""
|
||||
return self._message_ids if self._message_ids else []
|
||||
|
||||
def end_current_round(self) -> None:
|
||||
"""End the current round of conversation
|
||||
|
||||
Save the conversation to the storage after a round of conversation
|
||||
"""
|
||||
self.save_to_storage()
|
||||
|
||||
def _get_message_items(self) -> List[MessageStorageItem]:
|
||||
return [
|
||||
MessageStorageItem(self.conv_uid, message.index, message.to_dict())
|
||||
for message in self.messages
|
||||
]
|
||||
|
||||
def save_to_storage(self) -> None:
|
||||
"""Save the conversation to the storage"""
|
||||
# Save messages first
|
||||
message_list = self._get_message_items()
|
||||
self._message_ids = [
|
||||
message.identifier.str_identifier for message in message_list
|
||||
]
|
||||
self.message_storage.save_list(message_list)
|
||||
# Save conversation
|
||||
self.conv_storage.save_or_update(self)
|
||||
|
||||
def load_from_storage(
|
||||
self, conv_storage: StorageInterface, message_storage: StorageInterface
|
||||
) -> None:
|
||||
"""Load the conversation from the storage
|
||||
|
||||
Warning: This will overwrite the current conversation.
|
||||
|
||||
Args:
|
||||
conv_storage (StorageInterface): The storage interface
|
||||
message_storage (StorageInterface): The storage interface
|
||||
"""
|
||||
# Load conversation first
|
||||
conversation: StorageConversation = conv_storage.load(
|
||||
self._id, StorageConversation
|
||||
)
|
||||
if conversation is None:
|
||||
return
|
||||
message_ids = conversation._message_ids or []
|
||||
|
||||
# Load messages
|
||||
message_list = message_storage.load_list(
|
||||
[
|
||||
MessageIdentifier.from_str_identifier(message_id)
|
||||
for message_id in message_ids
|
||||
],
|
||||
MessageStorageItem,
|
||||
)
|
||||
messages = [message.to_message() for message in message_list]
|
||||
conversation.messages = messages
|
||||
self._message_ids = message_ids
|
||||
self.from_conversation(conversation)
|
||||
|
||||
|
||||
def _conversation_to_dict(once: OnceConversation) -> Dict:
|
||||
start_str: str = ""
|
||||
if hasattr(once, "start_date") and once.start_date:
|
||||
if isinstance(once.start_date, datetime):
|
||||
@@ -303,6 +752,7 @@ def _conversation_to_dict(once: OnceConversation) -> dict:
|
||||
"param_value": once.param_value,
|
||||
"user_name": once.user_name,
|
||||
"sys_code": once.sys_code,
|
||||
"summary": once.summary if once.summary else "",
|
||||
}
|
||||
|
||||
|
||||
|
@@ -92,7 +92,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
f"""Model server error!code={resp_obj_ex["error_code"]}, errmsg is {resp_obj_ex["text"]}"""
|
||||
)
|
||||
|
||||
def __illegal_json_ends(self, s):
|
||||
def _illegal_json_ends(self, s):
|
||||
temp_json = s
|
||||
illegal_json_ends_1 = [", }", ",}"]
|
||||
illegal_json_ends_2 = ", ]", ",]"
|
||||
@@ -102,25 +102,25 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
temp_json = temp_json.replace(illegal_json_end, " ]")
|
||||
return temp_json
|
||||
|
||||
def __extract_json(self, s):
|
||||
def _extract_json(self, s):
|
||||
try:
|
||||
# Get the dual-mode analysis first and get the maximum result
|
||||
temp_json_simple = self.__json_interception(s)
|
||||
temp_json_array = self.__json_interception(s, True)
|
||||
temp_json_simple = self._json_interception(s)
|
||||
temp_json_array = self._json_interception(s, True)
|
||||
if len(temp_json_simple) > len(temp_json_array):
|
||||
temp_json = temp_json_simple
|
||||
else:
|
||||
temp_json = temp_json_array
|
||||
|
||||
if not temp_json:
|
||||
temp_json = self.__json_interception(s)
|
||||
temp_json = self._json_interception(s)
|
||||
|
||||
temp_json = self.__illegal_json_ends(temp_json)
|
||||
temp_json = self._illegal_json_ends(temp_json)
|
||||
return temp_json
|
||||
except Exception as e:
|
||||
raise ValueError("Failed to find a valid json in LLM response!" + temp_json)
|
||||
|
||||
def __json_interception(self, s, is_json_array: bool = False):
|
||||
def _json_interception(self, s, is_json_array: bool = False):
|
||||
try:
|
||||
if is_json_array:
|
||||
i = s.find("[")
|
||||
@@ -176,7 +176,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
cleaned_output = cleaned_output.strip()
|
||||
if not cleaned_output.startswith("{") or not cleaned_output.endswith("}"):
|
||||
logger.info("illegal json processing:\n" + cleaned_output)
|
||||
cleaned_output = self.__extract_json(cleaned_output)
|
||||
cleaned_output = self._extract_json(cleaned_output)
|
||||
|
||||
if not cleaned_output or len(cleaned_output) <= 0:
|
||||
return model_out_text
|
||||
@@ -188,7 +188,7 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
.replace("\\", " ")
|
||||
.replace("\_", "_")
|
||||
)
|
||||
cleaned_output = self.__illegal_json_ends(cleaned_output)
|
||||
cleaned_output = self._illegal_json_ends(cleaned_output)
|
||||
return cleaned_output
|
||||
|
||||
def parse_view_response(
|
||||
@@ -208,20 +208,6 @@ class BaseOutputParser(MapOperator[ModelOutput, Any], ABC):
|
||||
"""Instructions on how the LLM output should be formatted."""
|
||||
raise NotImplementedError
|
||||
|
||||
# @property
|
||||
# def _type(self) -> str:
|
||||
# """Return the type key."""
|
||||
# raise NotImplementedError(
|
||||
# f"_type property is not implemented in class {self.__class__.__name__}."
|
||||
# " This is required for serialization."
|
||||
# )
|
||||
|
||||
def dict(self, **kwargs: Any) -> Dict:
|
||||
"""Return dictionary representation of output parser."""
|
||||
output_parser_dict = super().dict()
|
||||
output_parser_dict["_type"] = self._type
|
||||
return output_parser_dict
|
||||
|
||||
async def map(self, input_value: ModelOutput) -> Any:
|
||||
"""Parse the output of an LLM call.
|
||||
|
||||
|
@@ -1,19 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Type, Dict
|
||||
|
||||
|
||||
class Serializable(ABC):
|
||||
serializer: "Serializer" = None
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the object's state to a dictionary."""
|
||||
|
||||
def serialize(self) -> bytes:
|
||||
"""Convert the object into bytes for storage or transmission.
|
||||
|
||||
Returns:
|
||||
bytes: The byte array after serialization
|
||||
"""
|
||||
if self.serializer is None:
|
||||
raise ValueError(
|
||||
"Serializer is not set. Please set the serializer before serialization."
|
||||
)
|
||||
return self.serializer.serialize(self)
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict:
|
||||
"""Convert the object's state to a dictionary."""
|
||||
def set_serializer(self, serializer: "Serializer") -> None:
|
||||
"""Set the serializer for current serializable object.
|
||||
|
||||
Args:
|
||||
serializer (Serializer): The serializer to set
|
||||
"""
|
||||
self.serializer = serializer
|
||||
|
||||
|
||||
class Serializer(ABC):
|
||||
|
409
dbgpt/core/interface/storage.py
Normal file
409
dbgpt/core/interface/storage.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
    """Interface for objects that uniquely identify a stored resource."""

    @property
    @abstractmethod
    def str_identifier(self) -> str:
        """Return the unique string form of this identifier.

        The string identifier is used to uniquely identify the resource.

        Returns:
            str: The string identifier of the resource
        """

    def __hash__(self) -> int:
        """Hash on the string identifier so equal identifiers hash alike."""
        return hash(self.str_identifier)

    def __eq__(self, other: Any) -> bool:
        """Two identifiers are equal iff their string identifiers match."""
        return (
            isinstance(other, ResourceIdentifier)
            and self.str_identifier == other.str_identifier
        )
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class StorageItem(Serializable, ABC):
    """Interface implemented by every object that can live in a storage."""

    @property
    @abstractmethod
    def identifier(self) -> ResourceIdentifier:
        """Return the identifier that uniquely locates this item.

        Returns:
            ResourceIdentifier: The resource identifier of the storage item
        """

    @abstractmethod
    def merge(self, other: "StorageItem") -> None:
        """Fold the state of ``other`` into this item in place.

        Args:
            other (StorageItem): The other storage item
        """
|
||||
|
||||
|
||||
# Type variables: T is the user-facing storage item, TDataRepresentation is
# the backend's on-storage representation of that item.
T = TypeVar("T", bound=StorageItem)
TDataRepresentation = TypeVar("TDataRepresentation")


class StorageItemAdapter(Generic[T, TDataRepresentation], ABC):
    """Converts storage items to and from the underlying storage format.

    Sometimes the storage item is not the same as the storage format,
    so we need to convert the storage item to the storage format and vice versa.

    In database storage, the storage format is the database model, but the
    StorageItem is the user-defined object.

    Note: inherits ABC so the @abstractmethod declarations below are actually
    enforced at instantiation time (Generic alone does not enforce them).
    """

    @abstractmethod
    def to_storage_format(self, item: T) -> TDataRepresentation:
        """Convert the storage item to the storage format.

        Args:
            item (T): The storage item

        Returns:
            TDataRepresentation: The data in the storage format
        """

    @abstractmethod
    def from_storage_format(self, data: TDataRepresentation) -> T:
        """Convert the storage format back into a storage item.

        Args:
            data (TDataRepresentation): The data in the storage format

        Returns:
            T: The storage item
        """

    @abstractmethod
    def get_query_for_identifier(
        self,
        storage_format: Type[TDataRepresentation],
        resource_id: ResourceIdentifier,
        **kwargs,
    ) -> Any:
        """Build a backend-specific query selecting the given resource.

        Args:
            storage_format (Type[TDataRepresentation]): The storage format
            resource_id (ResourceIdentifier): The resource identifier
            kwargs: Additional backend-specific arguments

        Returns:
            Any: The query for the resource identifier
        """
|
||||
|
||||
|
||||
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
    """Identity adapter used when item and storage format are the same type.

    No conversion is required, so both directions return their input unchanged.
    """

    def to_storage_format(self, item: T) -> T:
        """Return the item unchanged."""
        return item

    def from_storage_format(self, data: T) -> T:
        """Return the data unchanged."""
        return data

    def get_query_for_identifier(
        self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
    ) -> bool:
        """No real query is needed for the identity adapter; always True."""
        return True
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class StorageError(Exception):
    """Base exception raised for storage-related failures."""

    def __init__(self, message: str):
        """Create a storage error carrying a human-readable message."""
        super().__init__(message)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class QuerySpec:
    """The query specification for querying data from the storage.

    Attributes:
        conditions (Dict[str, Any]): Attribute-name -> required-value pairs;
            an item matches when all pairs compare equal.
        limit (Optional[int]): The maximum number of items to return,
            or None for no limit.
        offset (int): The number of matching items to skip before returning.
    """

    def __init__(
        self, conditions: Dict[str, Any], limit: Optional[int] = None, offset: int = 0
    ) -> None:
        """Initialize the query specification.

        Args:
            conditions (Dict[str, Any]): The conditions for querying data
            limit (Optional[int]): The maximum number of items to return.
                Annotated Optional[int] (was `int`) since None is the default
                and means "no limit".
            offset (int): The offset of the data to return
        """
        self.conditions = conditions
        self.limit = limit
        self.offset = offset
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
class StorageInterface(Generic[T, TDataRepresentation], ABC):
    """The storage interface for storing and loading data."""

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
        adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
    ):
        """Create a storage backed by the given serializer and adapter.

        Args:
            serializer (Optional[Serializer]): Serializer converting items to
                bytes. Defaults to JsonSerializer.
            adapter (Optional[StorageItemAdapter]): Adapter converting items
                to/from the storage format. Defaults to the identity
                DefaultStorageItemAdapter.
        """
        self._serializer = serializer or JsonSerializer()
        self._storage_item_adapter = adapter or DefaultStorageItemAdapter()

    @property
    def serializer(self) -> Serializer:
        """Get the serializer of the storage.

        Returns:
            Serializer: The serializer of the storage
        """
        return self._serializer

    @property
    def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
        """Get the adapter of the storage.

        Returns:
            StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
        """
        return self._storage_item_adapter

    @abstractmethod
    def save(self, data: T) -> None:
        """Save the data to the storage.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If the data already exists in the storage or data is None
        """

    @abstractmethod
    def update(self, data: T) -> None:
        """Update the data in the storage.

        Args:
            data (T): The data to update

        Raises:
            StorageError: If data is None
        """

    @abstractmethod
    def save_or_update(self, data: T) -> None:
        """Save the data, overwriting any existing entry with the same identifier.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If data is None
        """

    def save_list(self, data: List[T]) -> None:
        """Save each item in the list to the storage.

        Args:
            data (List[T]): The items to save

        Raises:
            StorageError: If an item already exists in the storage or is None
        """
        for item in data:
            self.save(item)

    def save_or_update_list(self, data: List[T]) -> None:
        """Save or update each item in the list.

        Args:
            data (List[T]): The items to save or update
        """
        for item in data:
            self.save_or_update(item)

    @abstractmethod
    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load the data from the storage.

        None will be returned if the data does not exist in the storage.

        Loading by resource_id is faster than querying with conditions, so
        prefer load when possible.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
            Optional[T]: The loaded data, or None if not found
        """

    def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
        """Load multiple items from the storage.

        Identifiers that do not exist in the storage are silently skipped.

        Args:
            resource_id (List[ResourceIdentifier]): The resource identifiers
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The items that were found
        """
        result = []
        for rid in resource_id:
            item = self.load(rid, cls)
            if item is not None:
                result.append(item)
        return result

    @abstractmethod
    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete the data from the storage.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
        """

    @abstractmethod
    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Loading by resource_id is faster than querying with conditions, so
        prefer load when possible.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The queried data
        """

    @abstractmethod
    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of items matching the query specification.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            int: The number of matching items
        """

    def paginate_query(
        self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
    ) -> PaginationResult[T]:
        """Paginate the query result.

        Args:
            page (int): The page number (1-based)
            page_size (int): The number of items per page
            cls (Type[T]): The type of the data
            spec (Optional[QuerySpec], optional): The query specification. Defaults to None.

        Returns:
            PaginationResult[T]: The pagination result
        """
        if spec is None:
            spec = QuerySpec(conditions={})
        # Build a dedicated spec for this page instead of mutating the
        # caller's QuerySpec: the previous implementation assigned
        # limit/offset onto the argument itself, leaking pagination state
        # back to the caller.
        page_spec = QuerySpec(
            conditions=spec.conditions,
            limit=page_size,
            offset=(page - 1) * page_size,
        )
        items = self.query(page_spec, cls)
        total = self.count(page_spec, cls)
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=(total + page_size - 1) // page_size,
            page=page,
            page_size=page_size,
        )
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
class InMemoryStorage(StorageInterface[T, T]):
    """In-memory storage keeping serialized items in a plain dict."""

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
    ):
        """Create an in-memory storage.

        Args:
            serializer (Optional[Serializer]): Serializer for stored items.
                Defaults to JsonSerializer (see StorageInterface).
        """
        super().__init__(serializer)
        # Key: str identifier of the resource, Value: serialized bytes.
        self._data = {}

    def _prepare(self, data: T) -> None:
        """Validate the item and make sure it carries a serializer.

        Raises:
            StorageError: If data is None
        """
        # Previously `if not data:`, which would also reject valid items that
        # happen to be falsy (e.g. defining __bool__/__len__). Only None is
        # actually invalid here.
        if data is None:
            raise StorageError("Data cannot be None")
        if not data.serializer:
            data.set_serializer(self.serializer)

    def save(self, data: T) -> None:
        """Save a new item; fails if its identifier already exists."""
        self._prepare(data)
        key = data.identifier.str_identifier
        if key in self._data:
            raise StorageError(f"Data with identifier {key} already exists")
        self._data[key] = data.serialize()

    def update(self, data: T) -> None:
        """Insert or overwrite the item with the same identifier."""
        self._prepare(data)
        self._data[data.identifier.str_identifier] = data.serialize()

    def save_or_update(self, data: T) -> None:
        """Save or overwrite; in-memory update already behaves as an upsert."""
        self.update(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load an item by identifier, or None if absent."""
        serialized_data = self._data.get(resource_id.str_identifier)
        if serialized_data is None:
            return None
        return self.serializer.deserialize(serialized_data, cls)

    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Remove the item if present; missing identifiers are ignored."""
        self._data.pop(resource_id.str_identifier, None)

    def _matching_items(self, spec: QuerySpec, cls: Type[T]):
        """Yield deserialized items whose attributes match all conditions."""
        for serialized_data in self._data.values():
            item = self.serializer.deserialize(serialized_data, cls)
            if all(
                getattr(item, key) == value for key, value in spec.conditions.items()
            ):
                yield item

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query matching items, applying the spec's offset and limit."""
        result = list(self._matching_items(spec, cls))
        if spec.limit is not None:
            return result[spec.offset : spec.offset + spec.limit]
        return result[spec.offset :]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count matching items; offset/limit do not affect the count."""
        return sum(1 for _ in self._matching_items(spec, cls))
|
0
dbgpt/core/interface/tests/__init__.py
Normal file
0
dbgpt/core/interface/tests/__init__.py
Normal file
14
dbgpt/core/interface/tests/conftest.py
Normal file
14
dbgpt/core/interface/tests/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.storage import InMemoryStorage
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
|
||||
@pytest.fixture
def serializer():
    """Provide a JSON serializer instance for storage tests."""
    return JsonSerializer()


@pytest.fixture
def in_memory_storage(serializer):
    """Provide an InMemoryStorage backed by the JSON serializer fixture."""
    return InMemoryStorage(serializer)
|
307
dbgpt/core/interface/tests/test_message.py
Normal file
307
dbgpt/core/interface/tests/test_message.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import pytest
|
||||
|
||||
from dbgpt.core.interface.tests.conftest import in_memory_storage
|
||||
from dbgpt.core.interface.message import *
|
||||
|
||||
|
||||
@pytest.fixture
def basic_conversation():
    """An empty conversation with fixed chat mode, user and sys code."""
    return OnceConversation(chat_mode="chat_normal", user_name="user1", sys_code="sys1")


@pytest.fixture
def human_message():
    """A simple human message."""
    return HumanMessage(content="Hello")


@pytest.fixture
def ai_message():
    """A simple AI message."""
    return AIMessage(content="Hi there")


@pytest.fixture
def system_message():
    """A simple system message."""
    return SystemMessage(content="System update")


@pytest.fixture
def view_message():
    """A simple view message."""
    return ViewMessage(content="View this")


@pytest.fixture
def conversation_identifier():
    """Identifier for conversation 'conv1'."""
    return ConversationIdentifier("conv1")


@pytest.fixture
def message_identifier():
    """Identifier for message index 1 in conversation 'conv1'."""
    return MessageIdentifier("conv1", 1)


@pytest.fixture
def message_storage_item():
    """A MessageStorageItem wrapping a human message with index 1."""
    message = HumanMessage(content="Hello", index=1)
    message_detail = message.to_dict()
    return MessageStorageItem("conv1", 1, message_detail)


@pytest.fixture
def storage_conversation():
    """A storage-backed conversation without any storage attached yet."""
    return StorageConversation("conv1", chat_mode="chat_normal", user_name="user1")


@pytest.fixture
def conversation_with_messages():
    """A conversation with two completed rounds of user/AI exchange."""
    conv = OnceConversation(chat_mode="chat_normal", user_name="user1")

    conv.start_new_round()
    conv.add_user_message("Hello")
    conv.add_ai_message("Hi")
    conv.end_current_round()

    conv.start_new_round()
    conv.add_user_message("How are you?")
    conv.add_ai_message("I'm good, thanks")
    conv.end_current_round()

    return conv
|
||||
|
||||
|
||||
def test_init(basic_conversation):
    """A fresh conversation starts empty with zeroed counters."""
    assert basic_conversation.chat_mode == "chat_normal"
    assert basic_conversation.user_name == "user1"
    assert basic_conversation.sys_code == "sys1"
    assert basic_conversation.messages == []
    assert basic_conversation.start_date == ""
    assert basic_conversation.chat_order == 0
    assert basic_conversation.model_name == ""
    assert basic_conversation.param_type == ""
    assert basic_conversation.param_value == ""
    assert basic_conversation.cost == 0
    assert basic_conversation.tokens == 0
    assert basic_conversation._message_index == 0


def test_add_user_message(basic_conversation, human_message):
    """add_user_message appends exactly one HumanMessage."""
    basic_conversation.add_user_message(human_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], HumanMessage)


def test_add_ai_message(basic_conversation, ai_message):
    """add_ai_message appends exactly one AIMessage."""
    basic_conversation.add_ai_message(ai_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], AIMessage)


def test_add_system_message(basic_conversation, system_message):
    """add_system_message appends exactly one SystemMessage."""
    basic_conversation.add_system_message(system_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], SystemMessage)


def test_add_view_message(basic_conversation, view_message):
    """add_view_message appends exactly one ViewMessage."""
    basic_conversation.add_view_message(view_message.content)
    assert len(basic_conversation.messages) == 1
    assert isinstance(basic_conversation.messages[0], ViewMessage)


def test_set_start_time(basic_conversation):
    """set_start_time stores the timestamp formatted as '%Y-%m-%d %H:%M:%S'."""
    now = datetime.now()
    basic_conversation.set_start_time(now)
    assert basic_conversation.start_date == now.strftime("%Y-%m-%d %H:%M:%S")


def test_clear_messages(basic_conversation, human_message):
    """clear() removes all previously added messages."""
    basic_conversation.add_user_message(human_message.content)
    basic_conversation.clear()
    assert len(basic_conversation.messages) == 0


def test_get_latest_user_message(basic_conversation, human_message):
    """The most recently added user message is returned."""
    basic_conversation.add_user_message(human_message.content)
    latest_message = basic_conversation.get_latest_user_message()
    assert latest_message == human_message


def test_get_system_messages(basic_conversation, system_message):
    """Only system messages are returned, in insertion order."""
    basic_conversation.add_system_message(system_message.content)
    system_messages = basic_conversation.get_system_messages()
    assert len(system_messages) == 1
    assert system_messages[0] == system_message


def test_from_conversation(basic_conversation):
    """from_conversation copies chat mode and user name from the source."""
    new_conversation = OnceConversation(chat_mode="chat_advanced", user_name="user2")
    basic_conversation.from_conversation(new_conversation)
    assert basic_conversation.chat_mode == "chat_advanced"
    assert basic_conversation.user_name == "user2"


def test_get_messages_by_round(conversation_with_messages):
    """Messages are retrievable per round; unknown rounds yield nothing."""
    # First round contains the initial greeting exchange.
    round1_messages = conversation_with_messages.get_messages_by_round(1)
    assert len(round1_messages) == 2
    assert round1_messages[0].content == "Hello"
    assert round1_messages[1].content == "Hi"

    # A round that never happened returns an empty list.
    no_messages = conversation_with_messages.get_messages_by_round(3)
    assert len(no_messages) == 0


def test_get_latest_round(conversation_with_messages):
    """get_latest_round returns the last completed round's messages."""
    latest_round_messages = conversation_with_messages.get_latest_round()
    assert len(latest_round_messages) == 2
    assert latest_round_messages[0].content == "How are you?"
    assert latest_round_messages[1].content == "I'm good, thanks"


def test_get_messages_with_round(conversation_with_messages):
    """get_messages_with_round(n) returns the last n rounds of messages."""
    # Only the last round.
    last_round_messages = conversation_with_messages.get_messages_with_round(1)
    assert len(last_round_messages) == 2
    assert last_round_messages[0].content == "How are you?"
    assert last_round_messages[1].content == "I'm good, thanks"

    # The last two rounds, oldest first.
    last_two_rounds_messages = conversation_with_messages.get_messages_with_round(2)
    assert len(last_two_rounds_messages) == 4
    assert last_two_rounds_messages[0].content == "Hello"
    assert last_two_rounds_messages[1].content == "Hi"


def test_get_model_messages(conversation_with_messages):
    """All rounds convert into ModelMessage objects in order."""
    model_messages = conversation_with_messages.get_model_messages()
    assert len(model_messages) == 4
    assert all(isinstance(msg, ModelMessage) for msg in model_messages)
    assert model_messages[0].content == "Hello"
    assert model_messages[1].content == "Hi"
    assert model_messages[2].content == "How are you?"
    assert model_messages[3].content == "I'm good, thanks"
|
||||
def test_conversation_identifier(conversation_identifier):
    """A conversation identifier exposes uid, type, string form and dict form."""
    assert conversation_identifier.conv_uid == "conv1"
    assert conversation_identifier.identifier_type == "conversation"
    assert conversation_identifier.str_identifier == "conversation:conv1"
    assert conversation_identifier.to_dict() == {
        "conv_uid": "conv1",
        "identifier_type": "conversation",
    }


def test_message_identifier(message_identifier):
    """A message identifier combines conversation uid and message index."""
    assert message_identifier.conv_uid == "conv1"
    assert message_identifier.index == 1
    assert message_identifier.identifier_type == "message"
    assert message_identifier.str_identifier == "message___conv1___1"
    assert message_identifier.to_dict() == {
        "conv_uid": "conv1",
        "index": 1,
        "identifier_type": "message",
    }


def test_message_storage_item(message_storage_item):
    """A MessageStorageItem round-trips its message detail and identifier."""
    assert message_storage_item.conv_uid == "conv1"
    assert message_storage_item.index == 1
    assert message_storage_item.message_detail == {
        "type": "human",
        "data": {
            "content": "Hello",
            "index": 1,
            "round_index": 0,
            "additional_kwargs": {},
            "example": False,
        },
        "index": 1,
        "round_index": 0,
    }

    assert isinstance(message_storage_item.identifier, MessageIdentifier)
    assert message_storage_item.to_dict() == {
        "conv_uid": "conv1",
        "index": 1,
        "message_detail": {
            "type": "human",
            "index": 1,
            "data": {
                "content": "Hello",
                "index": 1,
                "round_index": 0,
                "additional_kwargs": {},
                "example": False,
            },
            "round_index": 0,
        },
    }

    assert isinstance(message_storage_item.to_message(), BaseMessage)


def test_storage_conversation_init(storage_conversation):
    """A new StorageConversation keeps its constructor arguments."""
    assert storage_conversation.conv_uid == "conv1"
    assert storage_conversation.chat_mode == "chat_normal"
    assert storage_conversation.user_name == "user1"


def test_storage_conversation_add_user_message(storage_conversation):
    """add_user_message appends exactly one HumanMessage."""
    storage_conversation.add_user_message("Hi")
    assert len(storage_conversation.messages) == 1
    assert isinstance(storage_conversation.messages[0], HumanMessage)


def test_storage_conversation_add_ai_message(storage_conversation):
    """add_ai_message appends exactly one AIMessage."""
    storage_conversation.add_ai_message("Hello")
    assert len(storage_conversation.messages) == 1
    assert isinstance(storage_conversation.messages[0], AIMessage)


def test_save_to_storage(storage_conversation, in_memory_storage):
    """Saved conversations can be reconstructed from the same storage."""
    # Attach storage backends.
    storage_conversation.conv_storage = in_memory_storage
    storage_conversation.message_storage = in_memory_storage

    # Record one user/AI exchange.
    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")

    # Persist everything.
    storage_conversation.save_to_storage()

    # Re-load through a brand-new StorageConversation instance.
    saved_conversation = StorageConversation(
        storage_conversation.conv_uid,
        conv_storage=in_memory_storage,
        message_storage=in_memory_storage,
    )

    assert saved_conversation.conv_uid == storage_conversation.conv_uid
    assert len(saved_conversation.messages) == 2
    assert isinstance(saved_conversation.messages[0], HumanMessage)
    assert isinstance(saved_conversation.messages[1], AIMessage)


def test_load_from_storage(storage_conversation, in_memory_storage):
    """Loading restores message contents and types in order."""
    # Attach storage backends.
    storage_conversation.conv_storage = in_memory_storage
    storage_conversation.message_storage = in_memory_storage

    # Record one exchange and persist it.
    storage_conversation.add_user_message("User message")
    storage_conversation.add_ai_message("AI response")
    storage_conversation.save_to_storage()

    # A fresh instance pointed at the same storage should see the data.
    new_conversation = StorageConversation(
        "conv1", conv_storage=in_memory_storage, message_storage=in_memory_storage
    )

    assert new_conversation.conv_uid == storage_conversation.conv_uid
    assert len(new_conversation.messages) == 2
    assert new_conversation.messages[0].content == "User message"
    assert new_conversation.messages[1].content == "AI response"
    assert isinstance(new_conversation.messages[0], HumanMessage)
    assert isinstance(new_conversation.messages[1], AIMessage)
|
129
dbgpt/core/interface/tests/test_storage.py
Normal file
129
dbgpt/core/interface/tests/test_storage.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import pytest
|
||||
from typing import Dict, Type, Union
|
||||
from dbgpt.core.interface.storage import (
|
||||
ResourceIdentifier,
|
||||
StorageError,
|
||||
QuerySpec,
|
||||
InMemoryStorage,
|
||||
StorageItem,
|
||||
)
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
|
||||
|
||||
class MockResourceIdentifier(ResourceIdentifier):
    """Minimal identifier backed by a plain string, for tests."""

    def __init__(self, identifier: str):
        self._identifier = identifier

    @property
    def str_identifier(self) -> str:
        """The wrapped string is the identifier itself."""
        return self._identifier

    def to_dict(self) -> Dict:
        """Serialize as a one-key dictionary."""
        return {"identifier": self._identifier}
|
||||
|
||||
|
||||
class MockStorageItem(StorageItem):
    """Minimal storage item carrying an identifier string and a data payload."""

    def __init__(self, identifier: Union[str, MockResourceIdentifier], data):
        # Accept either a raw string or a MockResourceIdentifier.
        if isinstance(identifier, str):
            self._identifier_str = identifier
        else:
            self._identifier_str = identifier.str_identifier
        self.data = data

    def merge(self, other: "StorageItem") -> None:
        """Adopt the other item's payload; only same-type merges are allowed."""
        if not isinstance(other, MockStorageItem):
            raise ValueError("other must be a MockStorageItem")
        self.data = other.data

    def to_dict(self) -> Dict:
        """Serialize identifier and payload."""
        return {"identifier": self._identifier_str, "data": self.data}

    @property
    def identifier(self) -> ResourceIdentifier:
        """Wrap the stored string in a MockResourceIdentifier."""
        return MockResourceIdentifier(self._identifier_str)
|
||||
|
||||
|
||||
@pytest.fixture
def serializer():
    """Provide a JSON serializer for the storage under test."""
    return JsonSerializer()


@pytest.fixture
def in_memory_storage(serializer):
    """Provide an InMemoryStorage using the JSON serializer fixture."""
    return InMemoryStorage(serializer)
|
||||
|
||||
|
||||
def test_save_and_load(in_memory_storage):
    """A saved item can be loaded back by its identifier."""
    resource_id = MockResourceIdentifier("1")
    item = MockStorageItem(resource_id, "test_data")

    in_memory_storage.save(item)

    loaded_item = in_memory_storage.load(resource_id, MockStorageItem)
    assert loaded_item.data == "test_data"


def test_duplicate_save(in_memory_storage):
    """Saving the same identifier twice raises StorageError."""
    item = MockStorageItem("1", "test_data")

    in_memory_storage.save(item)

    with pytest.raises(StorageError):
        in_memory_storage.save(item)


def test_delete(in_memory_storage):
    """Deleted items are no longer loadable."""
    resource_id = MockResourceIdentifier("1")
    item = MockStorageItem(resource_id, "test_data")

    in_memory_storage.save(item)
    in_memory_storage.delete(resource_id)
    assert in_memory_storage.load(resource_id, MockStorageItem) is None


def test_query(in_memory_storage):
    """Querying by attribute condition returns only matching items."""
    resource_id1 = MockResourceIdentifier("1")
    item1 = MockStorageItem(resource_id1, "test_data1")

    resource_id2 = MockResourceIdentifier("2")
    item2 = MockStorageItem(resource_id2, "test_data2")

    in_memory_storage.save(item1)
    in_memory_storage.save(item2)

    query_spec = QuerySpec(conditions={"data": "test_data1"})
    results = in_memory_storage.query(query_spec, MockStorageItem)
    assert len(results) == 1
    assert results[0].data == "test_data1"


def test_count(in_memory_storage):
    """Counting with no conditions returns the total number of items."""
    item1 = MockStorageItem("1", "test_data1")

    item2 = MockStorageItem("2", "test_data2")

    in_memory_storage.save(item1)
    in_memory_storage.save(item2)

    query_spec = QuerySpec(conditions={})
    count = in_memory_storage.count(query_spec, MockStorageItem)
    assert count == 2


def test_paginate_query(in_memory_storage):
    """paginate_query slices results into pages and reports totals."""
    for i in range(10):
        resource_id = MockResourceIdentifier(str(i))
        item = MockStorageItem(resource_id, f"test_data{i}")
        in_memory_storage.save(item)

    page_size = 3
    query_spec = QuerySpec(conditions={})
    page_result = in_memory_storage.paginate_query(
        2, page_size, MockStorageItem, query_spec
    )

    assert len(page_result.items) == page_size
    assert page_result.total_count == 10
    assert page_result.total_pages == 4
    assert page_result.page == 2
|
Reference in New Issue
Block a user