chore: Add pylint for storage (#1298)

Fangyin Cheng authored 2024-03-15 15:42:46 +08:00, committed by GitHub
parent a207640ff2
commit 8897d6e8fd
50 changed files with 784 additions and 667 deletions

View File

@@ -1,6 +1,7 @@
-from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
-from dbgpt.storage.cache.manager import CacheManager, initialize_cache
-from dbgpt.storage.cache.storage.base import MemoryCacheStorage
+"""Module for cache storage."""
+from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue  # noqa: F401
+from .manager import CacheManager, initialize_cache  # noqa: F401
+from .storage.base import MemoryCacheStorage  # noqa: F401

 __all__ = [
     "LLMCacheKey",

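For orientation, the re-exported public surface of this package is unchanged by the switch to relative imports; a minimal import sketch (usage only, not part of the diff):

from dbgpt.storage.cache import (
    CacheManager,
    LLMCacheClient,
    LLMCacheKey,
    LLMCacheValue,
    MemoryCacheStorage,
    initialize_cache,
)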
View File

@@ -0,0 +1 @@
"""Embeddings cache."""

View File

@@ -1,21 +1,26 @@
"""Cache client for LLM."""
import hashlib
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, Serializer
from dbgpt.core import ModelOutput
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
from dbgpt.model.base import ModelType
from dbgpt.storage.cache.manager import CacheManager
from .manager import CacheManager
@dataclass
class LLMCacheKeyData:
"""Cache key data for LLM."""
prompt: str
model_name: str
temperature: Optional[float] = 0.7
max_new_tokens: Optional[int] = None
top_p: Optional[float] = 1.0
model_type: Optional[str] = ModelType.HF
# See dbgpt.model.base.ModelType
model_type: Optional[str] = "huggingface"
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@@ -23,12 +28,15 @@ CacheOutputType = Union[ModelOutput, List[ModelOutput]]
 @dataclass
 class LLMCacheValueData:
+    """Cache value data for LLM."""

     output: CacheOutputType
     user: Optional[str] = None
-    _is_list: Optional[bool] = False
+    _is_list: bool = False

     @staticmethod
     def from_dict(**kwargs) -> "LLMCacheValueData":
+        """Create LLMCacheValueData object from dict."""
         output = kwargs.get("output")
         if not output:
             raise ValueError("Can't new LLMCacheValueData object, output is None")
@@ -46,6 +54,7 @@ class LLMCacheValueData:
         return LLMCacheValueData(**kwargs)

     def to_dict(self) -> Dict:
+        """Convert to dict."""
         output = self.output
         is_list = False
         if isinstance(output, list):
@@ -53,16 +62,18 @@ class LLMCacheValueData:
             is_list = True
             for out in output:
                 output_list.append(out.to_dict())
-            output = output_list
+            output = output_list  # type: ignore
         else:
-            output = output.to_dict()
+            output = output.to_dict()  # type: ignore
         return {"output": output, "_is_list": is_list, "user": self.user}

     @property
     def is_list(self) -> bool:
+        """Return whether the output is a list."""
         return self._is_list

     def __str__(self) -> str:
+        """Return string representation."""
         if not isinstance(self.output, list):
             return f"user: {self.user}, output: {self.output}"
         else:
@@ -70,74 +81,116 @@ class LLMCacheValueData:
 class LLMCacheKey(CacheKey[LLMCacheKeyData]):
+    """Cache key for LLM."""

     def __init__(self, **kwargs) -> None:
+        """Create a new instance of LLMCacheKey."""
         super().__init__()
         self.config = LLMCacheKeyData(**kwargs)

     def __hash__(self) -> int:
+        """Return the hash value of the object."""
         serialize_bytes = self.serialize()
         return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)

     def __eq__(self, other: Any) -> bool:
+        """Check equality with another key."""
         if not isinstance(other, LLMCacheKey):
             return False
         return self.config == other.config

     def get_hash_bytes(self) -> bytes:
+        """Return the byte array of hash value.
+
+        Returns:
+            bytes: The byte array of hash value.
+        """
         serialize_bytes = self.serialize()
         return hashlib.sha256(serialize_bytes).digest()

     def to_dict(self) -> Dict:
+        """Convert to dict."""
         return asdict(self.config)

     def get_value(self) -> LLMCacheKeyData:
+        """Return the real object of current cache key."""
         return self.config


 class LLMCacheValue(CacheValue[LLMCacheValueData]):
+    """Cache value for LLM."""

     def __init__(self, **kwargs) -> None:
+        """Create a new instance of LLMCacheValue."""
         super().__init__()
         self.value = LLMCacheValueData.from_dict(**kwargs)

     def to_dict(self) -> Dict:
+        """Convert to dict."""
         return self.value.to_dict()

     def get_value(self) -> LLMCacheValueData:
+        """Return the underlying real value."""
         return self.value

     def __str__(self) -> str:
+        """Return string representation."""
         return f"value: {str(self.value)}"


 class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
+    """Cache client for LLM."""

     def __init__(self, cache_manager: CacheManager) -> None:
+        """Create a new instance of LLMCacheClient."""
         super().__init__()
         self._cache_manager: CacheManager = cache_manager

     async def get(
-        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+        self,
+        key: LLMCacheKey,  # type: ignore
+        cache_config: Optional[CacheConfig] = None,
     ) -> Optional[LLMCacheValue]:
-        return await self._cache_manager.get(key, LLMCacheValue, cache_config)
+        """Retrieve a value from the cache using the provided key.
+
+        Args:
+            key (LLMCacheKey): The key to get cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            Optional[LLMCacheValue]: The value retrieved according to key. If cache key
+                not exist, return None.
+        """
+        return cast(
+            LLMCacheValue,
+            await self._cache_manager.get(key, LLMCacheValue, cache_config),
+        )

     async def set(
         self,
-        key: LLMCacheKey,
-        value: LLMCacheValue,
+        key: LLMCacheKey,  # type: ignore
+        value: LLMCacheValue,  # type: ignore
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         return await self._cache_manager.set(key, value, cache_config)

     async def exists(
-        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+        self,
+        key: LLMCacheKey,  # type: ignore
+        cache_config: Optional[CacheConfig] = None,
     ) -> bool:
+        """Check if a key exists in the cache."""
         return await self.get(key, cache_config) is not None

-    def new_key(self, **kwargs) -> LLMCacheKey:
+    def new_key(self, **kwargs) -> LLMCacheKey:  # type: ignore
+        """Create a cache key with params."""
         key = LLMCacheKey(**kwargs)
         key.set_serializer(self._cache_manager.serializer)
         return key

-    def new_value(self, **kwargs) -> LLMCacheValue:
+    def new_value(self, **kwargs) -> LLMCacheValue:  # type: ignore
+        """Create a cache value with params."""
         value = LLMCacheValue(**kwargs)
         value.set_serializer(self._cache_manager.serializer)
         return value
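Putting the reworked client API together, a hedged usage sketch (the key fields and method signatures are taken from this diff; the ModelOutput arguments and the demo wiring are illustrative assumptions):

from dbgpt.core import ModelOutput
from dbgpt.storage.cache import LLMCacheClient


async def demo(cache_manager) -> None:
    # cache_manager: a CacheManager, e.g. the one built by initialize_cache().
    client = LLMCacheClient(cache_manager)
    # new_key()/new_value() attach the manager's serializer before use.
    key = client.new_key(prompt="Hello", model_name="vicuna-13b", temperature=0.7)
    if not await client.exists(key):
        # Assumption: ModelOutput can be built from text and error_code.
        value = client.new_value(output=ModelOutput(text="Hi!", error_code=0))
        await client.set(key, value)
    cached = await client.get(key)  # Optional[LLMCacheValue]
    if cached:
        print(cached.get_value().output)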

View File

@@ -1,24 +1,31 @@
"""Cache manager."""
import logging
from abc import ABC, abstractmethod
from concurrent.futures import Executor
from typing import Optional, Type
from typing import Optional, Type, cast
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer
from dbgpt.core.interface.cache import K, V
from dbgpt.storage.cache.storage.base import CacheStorage
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
from .storage.base import CacheStorage
logger = logging.getLogger(__name__)
class CacheManager(BaseComponent, ABC):
"""The cache manager interface."""
name = ComponentType.MODEL_CACHE_MANAGER
def __init__(self, system_app: SystemApp | None = None):
"""Create cache manager."""
super().__init__(system_app)
def init_app(self, system_app: SystemApp):
"""Initialize cache manager."""
self.system_app = system_app
@abstractmethod
@@ -28,7 +35,7 @@ class CacheManager(BaseComponent, ABC):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ):
-        """Set cache"""
+        """Set cache with key."""

     @abstractmethod
     async def get(
@@ -36,27 +43,30 @@ class CacheManager(BaseComponent, ABC):
         key: CacheKey[K],
         cls: Type[Serializable],
         cache_config: Optional[CacheConfig] = None,
-    ) -> CacheValue[V]:
-        """Get cache with key"""
+    ) -> Optional[CacheValue[V]]:
+        """Retrieve cache with key."""

     @property
     @abstractmethod
     def serializer(self) -> Serializer:
-        """Get cache serializer"""
+        """Return serializer to serialize/deserialize cache value."""


 class LocalCacheManager(CacheManager):
+    """Local cache manager."""

     def __init__(
         self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
     ) -> None:
+        """Create local cache manager."""
         super().__init__(system_app)
         self._serializer = serializer
         self._storage = storage

     @property
     def executor(self) -> Executor:
-        """Return executor to submit task"""
-        self._executor = self.system_app.get_component(
+        """Return executor."""
+        return self.system_app.get_component(  # type: ignore
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()
@@ -66,6 +76,7 @@ class LocalCacheManager(CacheManager):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ):
+        """Set cache with key."""
         if self._storage.support_async():
             await self._storage.aset(key, value, cache_config)
         else:
@@ -78,7 +89,8 @@ class LocalCacheManager(CacheManager):
         key: CacheKey[K],
         cls: Type[Serializable],
         cache_config: Optional[CacheConfig] = None,
-    ) -> CacheValue[V]:
+    ) -> Optional[CacheValue[V]]:
+        """Retrieve cache with key."""
         if self._storage.support_async():
             item_bytes = await self._storage.aget(key, cache_config)
         else:
@@ -87,30 +99,42 @@ class LocalCacheManager(CacheManager):
             )
         if not item_bytes:
             return None
-        return self._serializer.deserialize(item_bytes.value_data, cls)
+        return cast(
+            CacheValue[V], self._serializer.deserialize(item_bytes.value_data, cls)
+        )

     @property
     def serializer(self) -> Serializer:
+        """Return serializer to serialize/deserialize cache value."""
         return self._serializer


 def initialize_cache(
     system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
 ):
-    from dbgpt.storage.cache.storage.base import MemoryCacheStorage
+    """Initialize cache manager.
+
+    Args:
+        system_app (SystemApp): The system app.
+        storage_type (str): The storage type.
+        max_memory_mb (int): The max memory in MB.
+        persist_dir (str): The persist directory.
+    """
     from dbgpt.util.serialization.json_serialization import JsonSerializer

-    cache_storage = None
+    from .storage.base import MemoryCacheStorage

     if storage_type == "disk":
         try:
-            from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage
+            from .storage.disk.disk_storage import DiskCacheStorage

-            cache_storage = DiskCacheStorage(
+            cache_storage: CacheStorage = DiskCacheStorage(
                 persist_dir, mem_table_buffer_mb=max_memory_mb
             )
         except ImportError as e:
             logger.warn(
-                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
+                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error "
+                f"message: {str(e)}"
             )
             cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
     else:
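The hunk is truncated at the branch above; for reference, a minimal sketch of driving this initialization path (the bare SystemApp construction and the persist path are illustrative assumptions; the parameter names come from the signature in this diff):

from dbgpt.component import SystemApp
from dbgpt.storage.cache import initialize_cache

system_app = SystemApp()
# "disk" tries the rocksdb-backed DiskCacheStorage and falls back to
# MemoryCacheStorage when that import fails (see the try/except above).
initialize_cache(
    system_app,
    storage_type="disk",
    max_memory_mb=256,
    persist_dir="/tmp/dbgpt_model_cache",
)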

View File

@@ -1,5 +1,6 @@
"""Operators for processing model outputs with caching support."""
import logging
from typing import AsyncIterator, Dict, List, Union
from typing import AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, ModelRequest
from dbgpt.core.awel import (
@@ -10,7 +11,9 @@ from dbgpt.core.awel import (
     StreamifyAbsOperator,
     TransformStreamAbsOperator,
 )
-from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
+
+from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
+from .manager import CacheManager

 logger = logging.getLogger(__name__)
@@ -26,15 +29,17 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
         **kwargs: Additional keyword arguments.

     Methods:
-        streamify: Processes a stream of inputs with cache support, yielding model outputs.
+        streamify: Processes a stream of inputs with cache support, yielding model
+            outputs.
     """

     def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        """Create a new instance of CachedModelStreamOperator."""
         super().__init__(**kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)

-    async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]:
+    async def streamify(self, input_value: ModelRequest):
         """Process inputs as a stream with cache support and yield model outputs.

         Args:
@@ -45,10 +50,13 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
llm_cache_value = await self._client.get(llm_cache_key)
logger.info(f"llm_cache_value: {llm_cache_value}")
for out in llm_cache_value.get_value().output:
yield out
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
for out in outputs:
yield cast(ModelOutput, out)
class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
@@ -63,6 +71,7 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
"""Create a new instance of CachedModelOperator."""
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
@@ -78,14 +87,18 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
llm_cache_value = await self._client.get(llm_cache_key)
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
logger.info(f"llm_cache_value: {llm_cache_value}")
return llm_cache_value.get_value().output
return cast(ModelOutput, llm_cache_value.get_value().output)
class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
"""
A branch operator that decides whether to use cached data or to process data using the model.
"""Branch operator for model processing with cache support.
A branch operator that decides whether to use cached data or to process data using
the model.
Args:
cache_manager (CacheManager): The cache manager for managing cache operations.
@@ -101,6 +114,7 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
         cache_task_name: str,
         **kwargs,
     ):
+        """Create a new instance of ModelCacheBranchOperator."""
         super().__init__(branches=None, **kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
@@ -110,10 +124,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
     async def branches(
         self,
     ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
-        """Defines branch logic based on cache availability.
+        """Branch logic based on cache availability.
+
+        Defines branch logic based on cache availability.

         Returns:
-            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
+            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping
+                branch functions to task names.
         """

         async def check_cache_true(input_value: ModelRequest) -> bool:
@@ -124,12 +141,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
             cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
             cache_value = await self._client.get(cache_key)
             logger.debug(
-                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
+                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
+                f"{cache_value}"
             )
             await self.current_dag_context.save_to_share_data(
                 _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
             )
-            return True if cache_value else False
+            return bool(cache_value)

         async def check_cache_false(input_value: ModelRequest):
             # Inverse of check_cache_true
@@ -152,22 +170,25 @@ class ModelStreamSaveCacheOperator(
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelStreamSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def transform_stream(
self, input_value: AsyncIterator[ModelOutput]
) -> AsyncIterator[ModelOutput]:
"""Transforms the input stream by saving the outputs to cache.
async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
"""Save the stream of model outputs to cache.
Transforms the input stream by saving the outputs to cache.
Args:
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model
outputs.
Returns:
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
AsyncIterator[ModelOutput]: The same input iterator, but the outputs are
saved to cache.
"""
llm_cache_key: LLMCacheKey = None
llm_cache_key: Optional[LLMCacheKey] = None
outputs = []
async for out in input_value:
if not llm_cache_key:
@@ -190,12 +211,13 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Saves a single model output to cache and returns it.
"""Save model output to cache.
Args:
input_value (ModelOutput): The output from the model to be cached.
@@ -213,7 +235,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
 def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
-    """Parses and extracts relevant fields from input to form a cache key dictionary.
+    """Parse and extract relevant fields from input to form a cache key dictionary.

     Args:
         input_value (Dict): The input dictionary containing model and prompt parameters.
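These operators are meant to be wired into an AWEL DAG around the real model task. A hedged wiring sketch, written as if in this module (the task names and the model operator are placeholders, not part of this diff):

from dbgpt.core.awel import DAG


def build_cached_dag(cache_manager: CacheManager, model_operator):
    # model_operator: a hypothetical operator that actually calls the LLM.
    with DAG("llm_cache_dag") as dag:
        branch = ModelCacheBranchOperator(
            cache_manager,
            model_task_name="model_task",  # placeholder task names
            cache_task_name="cache_task",
        )
        cached = CachedModelOperator(cache_manager, task_name="cache_task")
        save = ModelSaveCacheOperator(cache_manager)
        branch >> cached  # cache hit: replay the cached output
        branch >> model_operator >> save  # cache miss: run the model, then save
    return dag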

View File

@@ -0,0 +1 @@
"""Module for protocol."""

View File

@@ -0,0 +1 @@
"""Module for cache storage implementation."""

View File

@@ -1,3 +1,4 @@
"""Base cache storage class."""
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
@@ -22,8 +23,7 @@ logger = logging.getLogger(__name__)
 @dataclass
 class StorageItem:
-    """
-    A class representing a storage item.
+    """A class representing a storage item.

     This class encapsulates data related to a storage item, such as its length,
     the hash of the key, and the data for both the key and value.
@@ -44,6 +44,7 @@ class StorageItem:
     def build_from(
         key_hash: bytes, key_data: bytes, value_data: bytes
     ) -> "StorageItem":
+        """Build a StorageItem from the provided key and value data."""
         length = (
             32
             + _get_object_bytes(key_hash)
@@ -56,6 +57,7 @@ class StorageItem:
     @staticmethod
     def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
+        """Build a StorageItem from the provided key and value."""
         key_hash = key.get_hash_bytes()
         key_data = key.serialize()
         value_data = value.serialize()
@@ -105,6 +107,8 @@ class StorageItem:
 class CacheStorage(ABC):
+    """Base class for cache storage."""

     @abstractmethod
     def check_config(
         self,
@@ -122,6 +126,7 @@ class CacheStorage(ABC):
"""
def support_async(self) -> bool:
"""Check whether the storage support async operation."""
return False
@abstractmethod
@@ -135,7 +140,8 @@ class CacheStorage(ABC):
             cache_config (Optional[CacheConfig]): Cache config

         Returns:
-            Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
+            Optional[StorageItem]: The storage item retrieved according to key. If
+                cache key not exist, return None.
         """
async def aget(
@@ -148,7 +154,8 @@ class CacheStorage(ABC):
             cache_config (Optional[CacheConfig]): Cache config

         Returns:
-            Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
+            Optional[StorageItem]: The storage item of bytes retrieved according to
+                key. If cache key not exist, return None.
         """
         raise NotImplementedError
@@ -184,8 +191,11 @@ class CacheStorage(ABC):
 class MemoryCacheStorage(CacheStorage):
     """A simple in-memory cache storage implementation."""

     def __init__(self, max_memory_mb: int = 256):
-        self.cache = OrderedDict()
+        """Create a new instance of MemoryCacheStorage."""
+        self.cache: OrderedDict = OrderedDict()
         self.max_memory = max_memory_mb * 1024 * 1024
         self.current_memory_usage = 0
@@ -194,6 +204,7 @@ class MemoryCacheStorage(CacheStorage):
         cache_config: Optional[CacheConfig] = None,
         raise_error: Optional[bool] = True,
     ) -> bool:
+        """Check whether the CacheConfig is legal."""
         if (
             cache_config
             and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
@@ -208,10 +219,11 @@ class MemoryCacheStorage(CacheStorage):
     def get(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key."""
         self.check_config(cache_config, raise_error=True)
         # Exact match retrieval
         key_hash = hash(key)
-        item: StorageItem = self.cache.get(key_hash)
+        item: Optional[StorageItem] = self.cache.get(key_hash)
         logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")

         if not item:
@@ -226,6 +238,7 @@ class MemoryCacheStorage(CacheStorage):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         key_hash = hash(key)
         item = StorageItem.build_from_kv(key, value)
         # Calculate memory size of the new entry
@@ -242,6 +255,7 @@ class MemoryCacheStorage(CacheStorage):
     def exists(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> bool:
+        """Check if the key exists in the cache."""
         return self.get(key, cache_config) is not None

     def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
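To illustrate the size accounting that drives eviction, StorageItem can be exercised directly with raw bytes (a sketch; real usage goes through CacheKey/CacheValue objects created by LLMCacheClient so that a serializer is attached):

import hashlib

key_data = b'{"prompt": "hello"}'
value_data = b'{"output": "world"}'
item = StorageItem.build_from(
    key_hash=hashlib.sha256(key_data).digest(),
    key_data=key_data,
    value_data=value_data,
)
# item.length approximates the in-memory footprint (32 bytes of overhead plus
# the byte sizes of hash, key and value) that MemoryCacheStorage charges
# against max_memory before evicting the oldest entries.
print(item.length)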

View File

@@ -0,0 +1 @@
"""Disk cache storage implementation."""

View File

@@ -1,3 +1,7 @@
"""Disk storage for cache.
Implement the cache storage using rocksdb.
"""
import logging
from typing import Optional
@@ -11,14 +15,14 @@ from dbgpt.core.interface.cache import (
     RetrievalPolicy,
     V,
 )
-from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem
+
+from ..base import CacheStorage, StorageItem

 logger = logging.getLogger(__name__)


-def db_options(
-    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
-):
+def db_options(mem_table_buffer_mb: int = 256, background_threads: int = 2):
+    """Create rocksdb options."""
     opt = Options()
     # create table
     opt.create_if_missing(True)
@@ -42,9 +46,10 @@ def db_options(
 class DiskCacheStorage(CacheStorage):
+    """Disk cache storage using rocksdb."""

-    def __init__(
-        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
-    ) -> None:
+    def __init__(self, persist_dir: str, mem_table_buffer_mb: int = 256) -> None:
+        """Create a new instance of DiskCacheStorage."""
         super().__init__()
         self.db: Rdict = Rdict(
             persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
@@ -55,6 +60,7 @@ class DiskCacheStorage(CacheStorage):
         cache_config: Optional[CacheConfig] = None,
         raise_error: Optional[bool] = True,
     ) -> bool:
+        """Check whether the CacheConfig is legal."""
         if (
             cache_config
             and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
@@ -69,6 +75,7 @@ class DiskCacheStorage(CacheStorage):
     def get(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key."""
         self.check_config(cache_config, raise_error=True)

         # Exact match retrieval
@@ -86,6 +93,7 @@ class DiskCacheStorage(CacheStorage):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         item = StorageItem.build_from_kv(key, value)
         key_hash = item.key_hash
         self.db[key_hash] = item.serialize()
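A minimal standalone sketch of the disk backend (assumes the optional rocksdict dependency is installed; otherwise initialize_cache() in manager.py falls back to MemoryCacheStorage):

from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage

storage = DiskCacheStorage("/tmp/dbgpt_disk_cache", mem_table_buffer_mb=64)
# get()/set() take the same CacheKey/CacheValue objects as MemoryCacheStorage,
# but entries persist across restarts under the given directory.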

View File

@@ -1,5 +1,3 @@
 import pytest

 from dbgpt.util.memory_utils import _get_object_bytes

+from ..base import StorageItem
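The test module is truncated here; it can be exercised along these lines (a hypothetical test mirroring the length computation in StorageItem.build_from shown earlier in this diff):

import hashlib


def test_build_from_length():
    key_data, value_data = b"key", b"value"
    key_hash = hashlib.sha256(key_data).digest()
    item = StorageItem.build_from(key_hash, key_data, value_data)
    assert item.length == (
        32
        + _get_object_bytes(key_hash)
        + _get_object_bytes(key_data)
        + _get_object_bytes(value_data)
    )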