mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-10-22 17:39:02 +00:00
refactor: Refactor storage system (#937)
This commit is contained in:
409
dbgpt/core/interface/storage.py
Normal file
409
dbgpt/core/interface/storage.py
Normal file
@@ -0,0 +1,409 @@
|
||||
from typing import Generic, TypeVar, Type, Optional, Dict, Any, List
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from dbgpt.core.interface.serialization import Serializable, Serializer
|
||||
from dbgpt.util.serialization.json_serialization import JsonSerializer
|
||||
from dbgpt.util.annotations import PublicAPI
|
||||
from dbgpt.util.pagination_utils import PaginationResult
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class ResourceIdentifier(Serializable, ABC):
|
||||
"""The resource identifier interface for resource identifiers."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def str_identifier(self) -> str:
|
||||
"""Get the string identifier of the resource.
|
||||
|
||||
The string identifier is used to uniquely identify the resource.
|
||||
|
||||
Returns:
|
||||
str: The string identifier of the resource
|
||||
"""
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Return the hash value of the key."""
|
||||
return hash(self.str_identifier)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Check equality with another key."""
|
||||
if not isinstance(other, ResourceIdentifier):
|
||||
return False
|
||||
return self.str_identifier == other.str_identifier
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class StorageItem(Serializable, ABC):
|
||||
"""The storage item interface for storage items."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def identifier(self) -> ResourceIdentifier:
|
||||
"""Get the resource identifier of the storage item.
|
||||
|
||||
Returns:
|
||||
ResourceIdentifier: The resource identifier of the storage item
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def merge(self, other: "StorageItem") -> None:
|
||||
"""Merge the other storage item into the current storage item.
|
||||
|
||||
Args:
|
||||
other (StorageItem): The other storage item
|
||||
"""
|
||||
|
||||
|
||||
T = TypeVar("T", bound=StorageItem)
|
||||
TDataRepresentation = TypeVar("TDataRepresentation")
|
||||
|
||||
|
||||
class StorageItemAdapter(Generic[T, TDataRepresentation]):
|
||||
"""The storage item adapter for converting storage items to and from the storage format.
|
||||
|
||||
Sometimes, the storage item is not the same as the storage format,
|
||||
so we need to convert the storage item to the storage format and vice versa.
|
||||
|
||||
In database storage, the storage format is database model, but the StorageItem is the user-defined object.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def to_storage_format(self, item: T) -> TDataRepresentation:
|
||||
"""Convert the storage item to the storage format.
|
||||
|
||||
Args:
|
||||
item (T): The storage item
|
||||
|
||||
Returns:
|
||||
TDataRepresentation: The data in the storage format
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def from_storage_format(self, data: TDataRepresentation) -> T:
|
||||
"""Convert the storage format to the storage item.
|
||||
|
||||
Args:
|
||||
data (TDataRepresentation): The data in the storage format
|
||||
|
||||
Returns:
|
||||
T: The storage item
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_query_for_identifier(
|
||||
self,
|
||||
storage_format: Type[TDataRepresentation],
|
||||
resource_id: ResourceIdentifier,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Get the query for the resource identifier.
|
||||
|
||||
Args:
|
||||
storage_format (Type[TDataRepresentation]): The storage format
|
||||
resource_id (ResourceIdentifier): The resource identifier
|
||||
kwargs: The additional arguments
|
||||
|
||||
Returns:
|
||||
Any: The query for the resource identifier
|
||||
"""
|
||||
|
||||
|
||||
class DefaultStorageItemAdapter(StorageItemAdapter[T, T]):
|
||||
"""The default storage item adapter for converting storage items to and from the storage format.
|
||||
|
||||
The storage item is the same as the storage format, so no conversion is required.
|
||||
"""
|
||||
|
||||
def to_storage_format(self, item: T) -> T:
|
||||
return item
|
||||
|
||||
def from_storage_format(self, data: T) -> T:
|
||||
return data
|
||||
|
||||
def get_query_for_identifier(
|
||||
self, storage_format: Type[T], resource_id: ResourceIdentifier, **kwargs
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class StorageError(Exception):
|
||||
"""The base exception class for storage errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class QuerySpec:
|
||||
"""The query specification for querying data from the storage.
|
||||
|
||||
Attributes:
|
||||
conditions (Dict[str, Any]): The conditions for querying data
|
||||
limit (int): The maximum number of data to return
|
||||
offset (int): The offset of the data to return
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, conditions: Dict[str, Any], limit: int = None, offset: int = 0
|
||||
) -> None:
|
||||
self.conditions = conditions
|
||||
self.limit = limit
|
||||
self.offset = offset
|
||||
|
||||
|
||||
@PublicAPI(stability="beta")
|
||||
class StorageInterface(Generic[T, TDataRepresentation], ABC):
|
||||
"""The storage interface for storing and loading data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
serializer: Optional[Serializer] = None,
|
||||
adapter: Optional[StorageItemAdapter[T, TDataRepresentation]] = None,
|
||||
):
|
||||
self._serializer = serializer or JsonSerializer()
|
||||
self._storage_item_adapter = adapter or DefaultStorageItemAdapter()
|
||||
|
||||
@property
|
||||
def serializer(self) -> Serializer:
|
||||
"""Get the serializer of the storage.
|
||||
|
||||
Returns:
|
||||
Serializer: The serializer of the storage
|
||||
"""
|
||||
return self._serializer
|
||||
|
||||
@property
|
||||
def adapter(self) -> StorageItemAdapter[T, TDataRepresentation]:
|
||||
"""Get the adapter of the storage.
|
||||
|
||||
Returns:
|
||||
StorageItemAdapter[T, TDataRepresentation]: The adapter of the storage
|
||||
"""
|
||||
return self._storage_item_adapter
|
||||
|
||||
@abstractmethod
|
||||
def save(self, data: T) -> None:
|
||||
"""Save the data to the storage.
|
||||
|
||||
Args:
|
||||
data (T): The data to save
|
||||
|
||||
Raises:
|
||||
StorageError: If the data already exists in the storage or data is None
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def update(self, data: T) -> None:
|
||||
"""Update the data to the storage.
|
||||
|
||||
Args:
|
||||
data (T): The data to save
|
||||
|
||||
Raises:
|
||||
StorageError: If data is None
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def save_or_update(self, data: T) -> None:
|
||||
"""Save or update the data to the storage.
|
||||
|
||||
Args:
|
||||
data (T): The data to save
|
||||
|
||||
Raises:
|
||||
StorageError: If data is None
|
||||
"""
|
||||
|
||||
def save_list(self, data: List[T]) -> None:
|
||||
"""Save the data to the storage.
|
||||
|
||||
Args:
|
||||
data (T): The data to save
|
||||
|
||||
Raises:
|
||||
StorageError: If the data already exists in the storage or data is None
|
||||
"""
|
||||
for d in data:
|
||||
self.save(d)
|
||||
|
||||
def save_or_update_list(self, data: List[T]) -> None:
|
||||
"""Save or update the data to the storage.
|
||||
|
||||
Args:
|
||||
data (T): The data to save
|
||||
"""
|
||||
for d in data:
|
||||
self.save_or_update(d)
|
||||
|
||||
@abstractmethod
|
||||
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
|
||||
"""Load the data from the storage.
|
||||
|
||||
None will be returned if the data does not exist in the storage.
|
||||
|
||||
Load data with resource_id will be faster than query data with conditions,
|
||||
so we suggest to use load if possible.
|
||||
|
||||
Args:
|
||||
resource_id (ResourceIdentifier): The resource identifier of the data
|
||||
cls (Type[T]): The type of the data
|
||||
|
||||
Returns:
|
||||
Optional[T]: The loaded data
|
||||
"""
|
||||
|
||||
def load_list(self, resource_id: List[ResourceIdentifier], cls: Type[T]) -> List[T]:
|
||||
"""Load the data from the storage.
|
||||
|
||||
None will be returned if the data does not exist in the storage.
|
||||
|
||||
Load data with resource_id will be faster than query data with conditions,
|
||||
so we suggest to use load if possible.
|
||||
|
||||
Args:
|
||||
resource_id (ResourceIdentifier): The resource identifier of the data
|
||||
cls (Type[T]): The type of the data
|
||||
|
||||
Returns:
|
||||
Optional[T]: The loaded data
|
||||
"""
|
||||
result = []
|
||||
for r in resource_id:
|
||||
item = self.load(r, cls)
|
||||
if item is not None:
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, resource_id: ResourceIdentifier) -> None:
|
||||
"""Delete the data from the storage.
|
||||
|
||||
Args:
|
||||
resource_id (ResourceIdentifier): The resource identifier of the data
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
|
||||
"""Query data from the storage.
|
||||
|
||||
Query data with resource_id will be faster than query data with conditions, so please use load if possible.
|
||||
|
||||
Args:
|
||||
spec (QuerySpec): The query specification
|
||||
cls (Type[T]): The type of the data
|
||||
|
||||
Returns:
|
||||
List[T]: The queried data
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
|
||||
"""Count the number of data from the storage.
|
||||
|
||||
Args:
|
||||
spec (QuerySpec): The query specification
|
||||
cls (Type[T]): The type of the data
|
||||
|
||||
Returns:
|
||||
int: The number of data
|
||||
"""
|
||||
|
||||
def paginate_query(
|
||||
self, page: int, page_size: int, cls: Type[T], spec: Optional[QuerySpec] = None
|
||||
) -> PaginationResult[T]:
|
||||
"""Paginate the query result.
|
||||
|
||||
Args:
|
||||
page (int): The page number
|
||||
page_size (int): The number of items per page
|
||||
cls (Type[T]): The type of the data
|
||||
spec (Optional[QuerySpec], optional): The query specification. Defaults to None.
|
||||
|
||||
Returns:
|
||||
PaginationResult[T]: The pagination result
|
||||
"""
|
||||
if spec is None:
|
||||
spec = QuerySpec(conditions={})
|
||||
spec.limit = page_size
|
||||
spec.offset = (page - 1) * page_size
|
||||
items = self.query(spec, cls)
|
||||
total = self.count(spec, cls)
|
||||
return PaginationResult(
|
||||
items=items,
|
||||
total_count=total,
|
||||
total_pages=(total + page_size - 1) // page_size,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
|
||||
@PublicAPI(stability="alpha")
|
||||
class InMemoryStorage(StorageInterface[T, T]):
|
||||
"""The in-memory storage for storing and loading data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
serializer: Optional[Serializer] = None,
|
||||
):
|
||||
super().__init__(serializer)
|
||||
self._data = {} # Key: ResourceIdentifier, Value: Serialized data
|
||||
|
||||
def save(self, data: T) -> None:
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
|
||||
if data.identifier.str_identifier in self._data:
|
||||
raise StorageError(
|
||||
f"Data with identifier {data.identifier.str_identifier} already exists"
|
||||
)
|
||||
self._data[data.identifier.str_identifier] = data.serialize()
|
||||
|
||||
def update(self, data: T) -> None:
|
||||
if not data:
|
||||
raise StorageError("Data cannot be None")
|
||||
if not data.serializer:
|
||||
data.set_serializer(self.serializer)
|
||||
self._data[data.identifier.str_identifier] = data.serialize()
|
||||
|
||||
def save_or_update(self, data: T) -> None:
|
||||
self.update(data)
|
||||
|
||||
def load(self, resource_id: ResourceIdentifier, cls: Type[T]) -> Optional[T]:
|
||||
serialized_data = self._data.get(resource_id.str_identifier)
|
||||
if serialized_data is None:
|
||||
return None
|
||||
return self.serializer.deserialize(serialized_data, cls)
|
||||
|
||||
def delete(self, resource_id: ResourceIdentifier) -> None:
|
||||
if resource_id.str_identifier in self._data:
|
||||
del self._data[resource_id.str_identifier]
|
||||
|
||||
def query(self, spec: QuerySpec, cls: Type[T]) -> List[T]:
|
||||
result = []
|
||||
for serialized_data in self._data.values():
|
||||
data = self._serializer.deserialize(serialized_data, cls)
|
||||
if all(
|
||||
getattr(data, key) == value for key, value in spec.conditions.items()
|
||||
):
|
||||
result.append(data)
|
||||
|
||||
# Apply limit and offset
|
||||
if spec.limit is not None:
|
||||
result = result[spec.offset : spec.offset + spec.limit]
|
||||
else:
|
||||
result = result[spec.offset :]
|
||||
return result
|
||||
|
||||
def count(self, spec: QuerySpec, cls: Type[T]) -> int:
|
||||
count = 0
|
||||
for serialized_data in self._data.values():
|
||||
data = self._serializer.deserialize(serialized_data, cls)
|
||||
if all(
|
||||
getattr(data, key) == value for key, value in spec.conditions.items()
|
||||
):
|
||||
count += 1
|
||||
return count
|
Reference in New Issue
Block a user