Files
DB-GPT/dbgpt/core/interface/storage.py
2024-01-10 10:39:04 +08:00

419 lines
13 KiB
Python

import copy
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar

from dbgpt.core.interface.serialization import Serializable, Serializer
from dbgpt.util.annotations import PublicAPI
from dbgpt.util.pagination_utils import PaginationResult
from dbgpt.util.serialization.json_serialization import JsonSerializer
@PublicAPI(stability="beta")
class ResourceIdentifier(Serializable, ABC):
    """Abstract identifier that uniquely names a resource.

    Both equality and hashing are derived from :attr:`str_identifier`, so two
    identifiers with the same string form compare equal and hash alike.
    """

    @property
    @abstractmethod
    def str_identifier(self) -> str:
        """Return the string form that uniquely identifies the resource.

        Returns:
            str: The string identifier of the resource
        """

    def __eq__(self, other: Any) -> bool:
        """Compare by string identifier; non-identifiers are never equal."""
        return (
            isinstance(other, ResourceIdentifier)
            and self.str_identifier == other.str_identifier
        )

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__`` by hashing the string identifier."""
        identifier = self.str_identifier
        return hash(identifier)
@PublicAPI(stability="beta")
class StorageItem(Serializable, ABC):
    """Abstract base for objects that can be kept in a storage.

    Concrete items must expose a :class:`ResourceIdentifier` and know how to
    absorb the contents of another item.
    """

    @property
    @abstractmethod
    def identifier(self) -> ResourceIdentifier:
        """Return the identifier under which this item is stored.

        Returns:
            ResourceIdentifier: The resource identifier of the storage item
        """

    @abstractmethod
    def merge(self, other: "StorageItem") -> None:
        """Fold ``other`` into this item in place.

        Args:
            other (StorageItem): The storage item whose contents are merged in
        """
# Type of the data as it is physically kept by a storage backend.
TDataRepresentation = TypeVar("TDataRepresentation")
# Type of the user-facing storage item; restricted to StorageItem subclasses.
T = TypeVar("T", bound=StorageItem)
class StorageItemAdapter(Generic[T, TDataRepresentation]):
    """Translate between storage items and the backend's storage format.

    The user-defined storage item is not always what the backend persists; for
    a database backend, for example, the storage format is a database model.
    An adapter performs the conversion in both directions and builds lookup
    queries for a resource identifier.

    NOTE(review): this class does not derive from ``ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation time.
    """

    @abstractmethod
    def to_storage_format(self, item: T) -> TDataRepresentation:
        """Convert a storage item into the storage format.

        Args:
            item (T): The storage item

        Returns:
            TDataRepresentation: The data in the storage format
        """

    @abstractmethod
    def from_storage_format(self, data: TDataRepresentation) -> T:
        """Convert data in the storage format back into a storage item.

        Args:
            data (TDataRepresentation): The data in the storage format

        Returns:
            T: The storage item
        """

    @abstractmethod
    def get_query_for_identifier(
        self,
        storage_format: Type[TDataRepresentation],
        resource_id: ResourceIdentifier,
        **kwargs,
    ) -> Any:
        """Build the backend query that selects ``resource_id``.

        Args:
            storage_format (Type[TDataRepresentation]): The storage format type
            resource_id (ResourceIdentifier): The resource identifier
            kwargs: Additional backend-specific arguments

        Returns:
            Any: The query for the resource identifier
        """
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
    """Identity adapter used when the storage item is the storage format.

    No conversion is performed in either direction.
    """

    def to_storage_format(self, item: T) -> T:
        """Return ``item`` unchanged.

        Args:
            item (T): The storage item

        Returns:
            T: The same object that was passed in
        """
        return item

    def from_storage_format(self, data: T) -> T:
        """Return ``data`` unchanged.

        Args:
            data (T): The stored data

        Returns:
            T: The same object that was passed in
        """
        return data

    def get_query_for_identifier(
        self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
    ) -> bool:
        """Return ``True`` regardless of the arguments.

        Args:
            storage_format (Type[T]): The storage format type (unused)
            resource_id (ResourceIdentifier): The resource identifier (unused)
            kwargs: Additional arguments (unused)

        Returns:
            bool: Always ``True``
        """
        return True
@PublicAPI(stability="beta")
class StorageError(Exception):
    """Base exception for storage-related errors."""

    def __init__(self, message: str):
        """Create the error.

        Args:
            message (str): Human-readable description of the failure
        """
        super(StorageError, self).__init__(message)
@PublicAPI(stability="beta")
class QuerySpec:
    """The query specification for querying data from the storage.

    Attributes:
        conditions (Dict[str, Any]): The conditions for querying data, mapping
            a field name to the value it must have. An empty dict matches
            everything.
        limit (Optional[int]): The maximum number of data to return; ``None``
            means no limit.
        offset (int): The number of matching data to skip before returning
    """

    def __init__(
        self,
        conditions: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
        offset: int = 0,
    ) -> None:
        """Create a query specification.

        Args:
            conditions (Optional[Dict[str, Any]]): The query conditions.
                Defaults to ``None``, which is stored as an empty dict.
            limit (Optional[int]): The maximum number of data to return.
                Defaults to ``None`` (no limit).
            offset (int): The offset of the data to return. Defaults to 0.
        """
        # A fresh dict per instance avoids sharing a mutable default.
        self.conditions: Dict[str, Any] = {} if conditions is None else conditions
        self.limit: Optional[int] = limit
        self.offset: int = offset

    def __repr__(self) -> str:
        """Return a debug representation of the specification."""
        return (
            f"{type(self).__name__}(conditions={self.conditions!r}, "
            f"limit={self.limit!r}, offset={self.offset!r})"
        )
@PublicAPI(stability="beta")
class StorageInterface(Generic[T, TDataRepresentation], ABC):
    """The storage interface for storing and loading data."""

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
        adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
    ):
        """Create a storage.

        Args:
            serializer (Optional[Serializer]): The serializer of the storage.
                Defaults to a new ``JsonSerializer``.
            adapter (Optional[StorageItemAdapter[T, TDataRepresentation]]): The
                adapter between storage items and the storage format. Defaults
                to a new ``DefaultStorageItemAdapter``.
        """
        self._serializer = serializer or JsonSerializer()
        self._storage_item_adapter = adapter or DefaultStorageItemAdapter()

    @property
    def serializer(self) -> Serializer:
        """Get the serializer of the storage.

        Returns:
            Serializer: The serializer of the storage
        """
        return self._serializer

    @property
    def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
        """Get the adapter of the storage.

        Returns:
            StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
        """
        return self._storage_item_adapter

    @abstractmethod
    def save(self, data: T) -> None:
        """Save the data to the storage.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If the data already exists in the storage or data is None
        """

    @abstractmethod
    def update(self, data: T) -> None:
        """Update the data to the storage.

        Args:
            data (T): The data to update

        Raises:
            StorageError: If data is None
        """

    @abstractmethod
    def save_or_update(self, data: T) -> None:
        """Save or update the data to the storage.

        Args:
            data (T): The data to save or update

        Raises:
            StorageError: If data is None
        """

    def save_list(self, data: List[T]) -> None:
        """Save each item of ``data`` to the storage via :meth:`save`.

        Args:
            data (List[T]): The data to save

        Raises:
            StorageError: If data is None, or if :meth:`save` raises for an item
                (e.g. it already exists in the storage)
        """
        if data is None:
            raise StorageError("Data cannot be None")
        for d in data:
            self.save(d)

    def save_or_update_list(self, data: List[T]) -> None:
        """Save or update each item of ``data`` via :meth:`save_or_update`.

        Args:
            data (List[T]): The data to save or update

        Raises:
            StorageError: If data is None, or if :meth:`save_or_update` raises
        """
        if data is None:
            raise StorageError("Data cannot be None")
        for d in data:
            self.save_or_update(d)

    @abstractmethod
    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load the data from the storage.

        None will be returned if the data does not exist in the storage.

        Load data with resource_id will be faster than query data with conditions,
        so we suggest to use load if possible.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
            Optional[T]: The loaded data
        """

    def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
        """Load several items from the storage via :meth:`load`.

        Identifiers whose data does not exist in the storage are skipped, so the
        result may be shorter than ``resource_id``. Order follows the input.

        Args:
            resource_id (List[ResourceIdentifier]): The resource identifiers
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The loaded data that exists in the storage
        """
        loaded = (self.load(r, cls) for r in resource_id)
        return [item for item in loaded if item is not None]

    @abstractmethod
    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete the data from the storage.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
        """

    def delete_list(self, resource_id: List[ResourceIdentifier]) -> None:
        """Delete several items from the storage via :meth:`delete`.

        Args:
            resource_id (List[ResourceIdentifier]): The resource identifiers
        """
        for r in resource_id:
            self.delete(r)

    @abstractmethod
    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Query data from the storage.

        Loading by resource_id is faster than querying with conditions, so
        prefer :meth:`load` when possible.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The queried data
        """

    @abstractmethod
    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count the number of data from the storage.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            int: The number of data
        """

    def paginate_query(
        self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
    ) -> PaginationResult[T]:
        """Paginate the query result.

        The caller's ``spec`` is not modified; its ``limit``/``offset`` are
        replaced by the page window on a copy.

        Args:
            page (int): The 1-based page number
            page_size (int): The number of items per page
            cls (Type[T]): The type of the data
            spec (Optional[QuerySpec], optional): The query specification.
                Defaults to None, which matches all data.

        Returns:
            PaginationResult[T]: The pagination result

        Raises:
            ValueError: If page or page_size is smaller than 1
        """
        if page < 1:
            raise ValueError(f"page must be >= 1, got {page}")
        if page_size < 1:
            raise ValueError(f"page_size must be >= 1, got {page_size}")
        if spec is None:
            spec = QuerySpec(conditions={})
        # Copies keep the caller's spec (and any subclass attributes) intact.
        page_spec = copy.copy(spec)
        page_spec.limit = page_size
        page_spec.offset = (page - 1) * page_size
        # The total must cover every match, not just the current page window.
        count_spec = copy.copy(spec)
        count_spec.limit = None
        count_spec.offset = 0
        items = self.query(page_spec, cls)
        total = self.count(count_spec, cls)
        return PaginationResult(
            items=items,
            total_count=total,
            total_pages=(total + page_size - 1) // page_size,
            page=page,
            page_size=page_size,
        )
@PublicAPI(stability="alpha")
class InMemoryStorage(StorageInterface[T, T]):
    """The in-memory storage for storing and loading data.

    Items are kept serialized in a dict keyed by their string identifier, so
    every ``load``/``query`` returns a freshly deserialized object.
    """

    def __init__(
        self,
        serializer: Optional[Serializer] = None,
    ):
        """Create an in-memory storage.

        Args:
            serializer (Optional[Serializer]): The serializer used to deserialize
                data and assigned to items that have no serializer yet.
        """
        super().__init__(serializer)
        # Key: ``ResourceIdentifier.str_identifier``, value: serialized data
        self._data: Dict[str, Any] = {}

    def _prepare(self, data: Optional[T]) -> T:
        """Reject None and give ``data`` this storage's serializer if it has none."""
        # Compare to None explicitly: a valid item may define __len__/__bool__.
        if data is None:
            raise StorageError("Data cannot be None")
        if not data.serializer:
            data.set_serializer(self.serializer)
        return data

    def _matching_items(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Deserialize all stored data and keep items whose attributes equal the conditions."""
        conditions = spec.conditions.items()
        matches = []
        for serialized_data in self._data.values():
            item = self.serializer.deserialize(serialized_data, cls)
            if all(getattr(item, key) == value for key, value in conditions):
                matches.append(item)
        return matches

    def save(self, data: T) -> None:
        """Save new data.

        Args:
            data (T): The data to save

        Raises:
            StorageError: If data is None or its identifier already exists
        """
        item = self._prepare(data)
        key = item.identifier.str_identifier
        if key in self._data:
            raise StorageError(f"Data with identifier {key} already exists")
        self._data[key] = item.serialize()

    def update(self, data: T) -> None:
        """Store data, overwriting any existing entry with the same identifier.

        Args:
            data (T): The data to update

        Raises:
            StorageError: If data is None
        """
        item = self._prepare(data)
        self._data[item.identifier.str_identifier] = item.serialize()

    def save_or_update(self, data: T) -> None:
        """Save or update data; identical to :meth:`update`, which always overwrites.

        Args:
            data (T): The data to save or update

        Raises:
            StorageError: If data is None
        """
        self.update(data)

    def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
        """Load data by identifier.

        NOTE(review): items are serialized with their own serializer (this
        storage's serializer is only assigned when they have none) but are
        deserialized with this storage's serializer; confirm these always agree.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
            cls (Type[T]): The type of the data

        Returns:
            Optional[T]: The loaded data, or None if it does not exist
        """
        serialized_data = self._data.get(resource_id.str_identifier)
        if serialized_data is None:
            return None
        return self.serializer.deserialize(serialized_data, cls)

    def delete(self, resource_id: ResourceIdentifier) -> None:
        """Delete data by identifier; deleting a missing identifier is a no-op.

        Args:
            resource_id (ResourceIdentifier): The resource identifier of the data
        """
        self._data.pop(resource_id.str_identifier, None)

    def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
        """Return the matching data, windowed by ``spec.offset`` and ``spec.limit``.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            List[T]: The queried data
        """
        matches = self._matching_items(spec, cls)
        end = None if spec.limit is None else spec.offset + spec.limit
        return matches[spec.offset : end]

    def count(self, spec: QuerySpec, cls: Type[T]) -> int:
        """Count data matching ``spec.conditions``; limit/offset are ignored.

        Args:
            spec (QuerySpec): The query specification
            cls (Type[T]): The type of the data

        Returns:
            int: The number of matching data
        """
        return len(self._matching_items(spec, cls))