Files
DB-GPT/dbgpt/agent/memory/gpts_memory_storage.py
2024-02-20 14:28:58 +08:00

444 lines
14 KiB
Python

import dataclasses
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
from dbgpt.agent.common.schema import Status
from dbgpt.core.interface.storage import (
InMemoryStorage,
QuerySpec,
ResourceIdentifier,
StorageInterface,
StorageItem,
)
from .base import GptsMessage, GptsMessageMemory
@dataclass
class GptsPlanIdentifier(ResourceIdentifier):
    """Identifier of a stored plan entry, built from ``conv_id`` and ``sub_task_num``.

    The string form joins the key parts with ``identifier_split``, so no part
    may contain that separator.
    """

    # Separator used by ``str_identifier``; not accepted by ``__init__``.
    identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
    conv_id: str
    # Annotated as a string, but ``GptsPlanStorage.sub_task_num`` is an int, so
    # non-string values are tolerated and stringified wherever text is needed.
    sub_task_num: Optional[str]

    def __post_init__(self):
        """Validate the key parts.

        Raises:
            ValueError: If a key part is None or contains ``identifier_split``.
        """
        if self.conv_id is None or self.sub_task_num is None:
            raise ValueError("conv_id and sub_task_num cannot be None")
        if any(self.identifier_split in part for part in self._key_parts()):
            raise ValueError(
                f"identifier_split {self.identifier_split} is not allowed in conv_id, sub_task_num"
            )

    def _key_parts(self) -> List[str]:
        """Return the non-None key parts in order, each converted to ``str``."""
        return [
            str(part)
            for part in (self.conv_id, self.sub_task_num)
            if part is not None
        ]

    @property
    def str_identifier(self) -> str:
        """The key parts joined by ``identifier_split``."""
        return self.identifier_split.join(self._key_parts())

    def to_dict(self) -> Dict:
        """Return the key parts as a plain dict (values are not stringified)."""
        return {
            "conv_id": self.conv_id,
            "sub_task_num": self.sub_task_num,
        }
@dataclass
class GptsPlanStorage(StorageItem):
    """Storage item holding one sub-task of a Gpts plan."""

    conv_id: str
    sub_task_num: int
    sub_task_content: Optional[str]
    sub_task_title: Optional[str] = None
    sub_task_agent: Optional[str] = None
    resource_name: Optional[str] = None
    rely: Optional[str] = None
    agent_model: Optional[str] = None
    retry_times: Optional[int] = 0
    max_retry_times: Optional[int] = 5
    state: Optional[str] = Status.TODO.value
    result: Optional[str] = None
    # Derived in ``__post_init__``; never passed to ``__init__``.
    _identifier: GptsPlanIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Validate the item and derive its identifier.

        Raises:
            ValueError: If a required field is None (see ``_check``) or the
                identifier rejects the key parts.
        """
        self._check()
        # The identifier declares ``sub_task_num`` as a string, so convert the
        # integer task number explicitly.
        self._identifier = GptsPlanIdentifier(
            conv_id=self.conv_id,
            sub_task_num=str(self.sub_task_num),
        )

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "GptsPlanStorage":
        """Build a plan item from a dict.

        ``sub_task_num`` and ``sub_task_content`` must be present as keys;
        every other key is optional and falls back to the field default.

        Raises:
            KeyError: If ``sub_task_num`` or ``sub_task_content`` is missing.
            ValueError: If a required value is None.
        """
        return GptsPlanStorage(
            conv_id=d.get("conv_id"),
            sub_task_num=d["sub_task_num"],
            sub_task_content=d["sub_task_content"],
            sub_task_title=d.get("sub_task_title"),
            sub_task_agent=d.get("sub_task_agent"),
            resource_name=d.get("resource_name"),
            rely=d.get("rely"),
            agent_model=d.get("agent_model"),
            retry_times=d.get("retry_times", 0),
            max_retry_times=d.get("max_retry_times", 5),
            state=d.get("state", Status.TODO.value),
            result=d.get("result"),
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the constructor fields as a dict.

        The derived ``_identifier`` is left out so that the result contains
        exactly the keys ``from_dict`` reads.
        """
        return {
            f.name: getattr(self, f.name)
            for f in dataclasses.fields(self)
            if f.init
        }

    def _check(self):
        """Ensure the required fields are set.

        Raises:
            ValueError: If ``conv_id``, ``sub_task_num``, ``sub_task_content``
                or ``state`` is None.
        """
        if self.conv_id is None:
            raise ValueError("conv_id cannot be None")
        if self.sub_task_num is None:
            raise ValueError("sub_task_num cannot be None")
        if self.sub_task_content is None:
            raise ValueError("sub_task_content cannot be None")
        if self.state is None:
            raise ValueError("state cannot be None")

    @property
    def identifier(self) -> GptsPlanIdentifier:
        """The identifier derived from ``conv_id`` and ``sub_task_num``."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge the other item into the current item.

        Args:
            other (StorageItem): The other item to merge

        Raises:
            ValueError: If ``other`` is not a ``GptsPlanStorage``.
        """
        if not isinstance(other, GptsPlanStorage):
            raise ValueError(
                f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
            )
        # NOTE(review): ``from_object`` is not defined in this class; presumably
        # it is inherited from StorageItem - confirm.
        self.from_object(other)
@dataclass
class GptsMessageIdentifier(ResourceIdentifier):
    """Identifier of a stored message: conversation, sender, receiver and round.

    The string form joins the key parts with ``identifier_split``, so no part
    may contain that separator.
    """

    # Separator used by ``str_identifier``; not accepted by ``__init__``.
    identifier_split: str = dataclasses.field(default="___$$$$___", init=False)
    conv_id: str
    sender: Optional[str]
    receiver: Optional[str]
    rounds: Optional[int]

    def __post_init__(self):
        """Validate the key parts.

        Raises:
            ValueError: If a key part is None or contains ``identifier_split``.
        """
        if (
            self.conv_id is None
            or self.sender is None
            or self.receiver is None
            or self.rounds is None
        ):
            raise ValueError("conv_id, sender, receiver and rounds cannot be None")
        if any(self.identifier_split in part for part in self._key_parts()):
            raise ValueError(
                f"identifier_split {self.identifier_split} is not allowed in conv_id, sender, receiver, rounds"
            )

    def _key_parts(self) -> List[str]:
        """Return the non-None key parts in order, each converted to ``str``.

        ``rounds`` is an int, so it must be stringified before it can be
        searched for the separator or joined.
        """
        return [
            str(part)
            for part in (self.conv_id, self.sender, self.receiver, self.rounds)
            if part is not None
        ]

    @property
    def str_identifier(self) -> str:
        """The key parts joined by ``identifier_split``."""
        return self.identifier_split.join(self._key_parts())

    def to_dict(self) -> Dict:
        """Return the key parts as a plain dict (values are not stringified)."""
        return {
            "conv_id": self.conv_id,
            "sender": self.sender,
            "receiver": self.receiver,
            "rounds": self.rounds,
        }


@dataclass
class GptsMessageStorage(StorageItem):
    """Storage item holding one Gpts message."""

    conv_id: str
    sender: str
    receiver: str
    role: str
    content: str
    rounds: Optional[int]
    current_goal: Optional[str] = None
    context: Optional[str] = None
    review_info: Optional[str] = None
    action_report: Optional[str] = None
    model_name: Optional[str] = None
    # ``default_factory`` so each instance gets a timestamp; a plain
    # ``= datetime.utcnow`` would store the function object itself.
    created_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
    updated_at: datetime = dataclasses.field(default_factory=datetime.utcnow)
    # Derived in ``__post_init__``; never passed to ``__init__``.
    _identifier: GptsMessageIdentifier = dataclasses.field(init=False)

    def __post_init__(self):
        """Validate the item and derive its identifier.

        Raises:
            ValueError: If a key field is None (see ``_check``) or the
                identifier rejects the key parts.
        """
        self._check()
        self._identifier = GptsMessageIdentifier(
            conv_id=self.conv_id,
            sender=self.sender,
            receiver=self.receiver,
            rounds=self.rounds,
        )

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> "GptsMessageStorage":
        """Build a message item from a dict.

        ``conv_id``, ``sender``, ``receiver``, ``role``, ``content`` and
        ``rounds`` must be present as keys; the remaining keys are optional.

        Raises:
            KeyError: If one of the required keys is missing.
            ValueError: If a key field is None.
        """
        # Forward timestamps only when they carry a value, so that a missing
        # or None entry falls back to the field's default factory.
        timestamps = {
            key: d[key]
            for key in ("created_at", "updated_at")
            if d.get(key) is not None
        }
        return GptsMessageStorage(
            conv_id=d["conv_id"],
            sender=d["sender"],
            receiver=d["receiver"],
            role=d["role"],
            content=d["content"],
            rounds=d["rounds"],
            model_name=d.get("model_name"),
            current_goal=d.get("current_goal"),
            context=d.get("context"),
            review_info=d.get("review_info"),
            action_report=d.get("action_report"),
            **timestamps,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Return the constructor fields as a dict.

        The derived ``_identifier`` is left out so that the result contains
        exactly the keys ``from_dict`` reads.
        """
        return {
            f.name: getattr(self, f.name)
            for f in dataclasses.fields(self)
            if f.init
        }

    def _check(self):
        """Ensure the fields that make up the identifier are set.

        Raises:
            ValueError: If ``conv_id``, ``sender``, ``receiver`` or ``rounds``
                is None.
        """
        if self.conv_id is None:
            raise ValueError("conv_id cannot be None")
        if self.sender is None:
            raise ValueError("sender cannot be None")
        if self.receiver is None:
            raise ValueError("receiver cannot be None")
        if self.rounds is None:
            raise ValueError("rounds cannot be None")

    def to_gpts_message(self) -> GptsMessage:
        """Convert the storage to a GptsMessage."""
        return GptsMessage(
            conv_id=self.conv_id,
            sender=self.sender,
            receiver=self.receiver,
            role=self.role,
            content=self.content,
            rounds=self.rounds,
            current_goal=self.current_goal,
            context=self.context,
            review_info=self.review_info,
            action_report=self.action_report,
            model_name=self.model_name,
            created_at=self.created_at,
            updated_at=self.updated_at,
        )

    @staticmethod
    def from_gpts_message(gpts_message: GptsMessage) -> "GptsMessageStorage":
        """Convert a GptsMessage to a storage item, copying every field."""
        return GptsMessageStorage(
            conv_id=gpts_message.conv_id,
            sender=gpts_message.sender,
            receiver=gpts_message.receiver,
            role=gpts_message.role,
            content=gpts_message.content,
            rounds=gpts_message.rounds,
            current_goal=gpts_message.current_goal,
            context=gpts_message.context,
            review_info=gpts_message.review_info,
            action_report=gpts_message.action_report,
            model_name=gpts_message.model_name,
            created_at=gpts_message.created_at,
            updated_at=gpts_message.updated_at,
        )

    @property
    def identifier(self) -> GptsMessageIdentifier:
        """The identifier derived from conv_id, sender, receiver and rounds."""
        return self._identifier

    def merge(self, other: "StorageItem") -> None:
        """Merge the other item into the current item.

        Args:
            other (StorageItem): The other item to merge

        Raises:
            ValueError: If ``other`` is not a ``GptsMessageStorage``.
        """
        if not isinstance(other, GptsMessageStorage):
            raise ValueError(
                f"Cannot merge {type(other)} into {type(self)} because they are not the same type."
            )
        # NOTE(review): ``from_object`` is not defined in this class; presumably
        # it is inherited from StorageItem - confirm.
        self.from_object(other)
class GptsMessageManager(GptsMessageMemory):
    """The manager class for GptsMessage.

    Simple wrapper for the storage interface.

    TODO: Import gpts storage with storage interface.
    """

    def __init__(self, storage: Optional[StorageInterface[GptsMessage, Any]] = None):
        """Create the manager.

        Args:
            storage: Backing storage; an ``InMemoryStorage`` is created when
                None is given.
        """
        if storage is None:
            storage = InMemoryStorage()
        self._storage = storage

    @property
    def storage(self) -> StorageInterface[GptsMessage, Any]:
        """The storage interface backing this manager."""
        return self._storage

    def append(self, message: GptsMessage):
        """Convert the message to a storage item and save it."""
        self.storage.save(GptsMessageStorage.from_gpts_message(message))

    def get_by_agent(self, conv_id: str, agent: str) -> Optional[List[GptsMessage]]:
        """Return the messages of a conversation that involve ``agent``.

        Args:
            conv_id: Conversation to search.
            agent: Agent name matched against each message's sender and
                receiver.

        Returns:
            The matching messages, or an empty list when nothing matches.
        """
        query_spec = QuerySpec(conditions={"conv_id": conv_id})
        stored: List[GptsMessageStorage] = self.storage.query(
            query_spec, GptsMessageStorage
        )
        if not stored:
            return []
        # NOTE(review): "by agent" is taken to mean the agent is either the
        # sender or the receiver - confirm against GptsMessageMemory's contract.
        return [
            item.to_gpts_message()
            for item in stored
            if agent in (item.sender, item.receiver)
        ]

    def get_between_agents(
        self,
        conv_id: str,
        agent1: str,
        agent2: str,
        current_goal: Optional[str] = None,
    ) -> Optional[List[GptsMessage]]:
        """Delegate to the ``GptsMessageMemory`` implementation.

        This method does not read from ``self.storage``.
        """
        return super().get_between_agents(conv_id, agent1, agent2, current_goal)

    def get_by_conv_id(self, conv_id: str) -> Optional[List[GptsMessage]]:
        """Delegate to the ``GptsMessageMemory`` implementation.

        This method does not read from ``self.storage``.
        """
        return super().get_by_conv_id(conv_id)

    def get_last_message(self, conv_id: str) -> Optional[GptsMessage]:
        """Delegate to the ``GptsMessageMemory`` implementation.

        This method does not read from ``self.storage``.
        """
        return super().get_last_message(conv_id)

    def prefer_query(
        self,
        prompt_name: str,
        sys_code: Optional[str] = None,
        prefer_prompt_language: Optional[str] = None,
        prefer_model: Optional[str] = None,
        **kwargs,
    ) -> List[GptsMessage]:
        """Query prompt templates from storage with prefer params.

        NOTE(review): this method appears to be carried over from a prompt
        template manager. It references ``StoragePromptTemplate``, which is not
        imported in the visible part of this module, and filters on
        ``prompt_language`` / ``model`` attributes that ``GptsMessageStorage``
        does not declare. The code is left unchanged; confirm whether it should
        be removed or adapted.

        Sometimes, we want to query prompt templates with prefer params(e.g. some language or some model).
        This method will query prompt templates with prefer params first, if not found, will query all prompt templates.

        Examples:
            Query a prompt template.

            .. code-block:: python

                prompt_template_list = prompt_manager.prefer_query("hello")

            Query with sys_code and username.

            .. code-block:: python

                prompt_template_list = prompt_manager.prefer_query(
                    "hello", sys_code="sys_code", user_name="user_name"
                )

            Query with prefer prompt language.

            .. code-block:: python

                # First query with prompt name "hello" exactly.
                # Second filter with prompt language "zh-cn", if not found, will return all prompt templates.
                prompt_template_list = prompt_manager.prefer_query(
                    "hello", prefer_prompt_language="zh-cn"
                )

            Query with prefer model.

            .. code-block:: python

                # First query with prompt name "hello" exactly.
                # Second filter with model "vicuna-13b-v1.5", if not found, will return all prompt templates.
                prompt_template_list = prompt_manager.prefer_query(
                    "hello", prefer_model="vicuna-13b-v1.5"
                )

        Args:
            prompt_name (str): The name of the prompt template.
            sys_code (Optional[str], optional): The system code of the prompt template. Defaults to None.
            prefer_prompt_language (Optional[str], optional): The language of the prompt template. Defaults to None.
            prefer_model (Optional[str], optional): The model of the prompt template. Defaults to None.
            kwargs (Dict): Other query params(If some key and value not None, so we query it exactly).
        """
        query_spec = QuerySpec(
            conditions={
                "prompt_name": prompt_name,
                "sys_code": sys_code,
                **kwargs,
            }
        )
        queries: List[StoragePromptTemplate] = self.storage.query(
            query_spec, StoragePromptTemplate
        )
        if not queries:
            return []
        if prefer_prompt_language:
            prefer_prompt_language = prefer_prompt_language.lower()
            temp_queries = [
                query
                for query in queries
                if query.prompt_language
                and query.prompt_language.lower() == prefer_prompt_language
            ]
            # Keep the unfiltered result when the preference matches nothing.
            if temp_queries:
                queries = temp_queries
        if prefer_model:
            prefer_model = prefer_model.lower()
            temp_queries = [
                query
                for query in queries
                if query.model and query.model.lower() == prefer_model
            ]
            # Keep the unfiltered result when the preference matches nothing.
            if temp_queries:
                queries = temp_queries
        return queries