refactor: The first refactored version for sdk release (#907)
Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
12 dbgpt/storage/cache/__init__.py vendored Normal file
@@ -0,0 +1,12 @@
from dbgpt.storage.cache.manager import CacheManager, initialize_cache
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
from dbgpt.storage.cache.llm_cache import LLMCacheKey, LLMCacheValue, LLMCacheClient

__all__ = [
    "LLMCacheKey",
    "LLMCacheValue",
    "LLMCacheClient",
    "CacheManager",
    "initialize_cache",
    "MemoryCacheStorage",
]
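For orientation, a minimal wiring sketch for these exports; the default `SystemApp()` construction is an assumption about `dbgpt.component`, not part of this diff:

from dbgpt.component import SystemApp
from dbgpt.storage.cache import initialize_cache

system_app = SystemApp()  # assumed: default-constructible here
# Register a LocalCacheManager backed by in-memory storage. Any
# storage_type other than "disk" falls back to memory (see manager.py).
initialize_cache(system_app, storage_type="memory", max_memory_mb=256, persist_dir="")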
0 dbgpt/storage/cache/embedding_cache.py vendored Normal file
152 dbgpt/storage/cache/llm_cache.py vendored Normal file
@@ -0,0 +1,152 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import hashlib

from dbgpt.core.interface.cache import (
    CacheKey,
    CacheValue,
    CacheClient,
    CacheConfig,
)
from dbgpt.storage.cache.manager import CacheManager
from dbgpt.core import ModelOutput, Serializer
from dbgpt.model.base import ModelType


@dataclass
class LLMCacheKeyData:
    prompt: str
    model_name: str
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    model_type: Optional[str] = ModelType.HF


CacheOutputType = Union[ModelOutput, List[ModelOutput]]


@dataclass
class LLMCacheValueData:
    output: CacheOutputType
    user: Optional[str] = None
    _is_list: Optional[bool] = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't create LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            output_list = []
            for out in output:
                if isinstance(out, dict):
                    out = ModelOutput(**out)
                output_list.append(out)
            output = output_list
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        output = self.output
        is_list = False
        if isinstance(output, list):
            output_list = []
            is_list = True
            for out in output:
                output_list.append(out.to_dict())
            output = output_list
        else:
            output = output.to_dict()
        return {"output": output, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        return self._is_list

    def __str__(self) -> str:
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output (last two items): {self.output[-2:]}"


class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        serialize_bytes = self.serialize()
        return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, LLMCacheKey):
            return False
        return self.config == other.config

    def get_hash_bytes(self) -> bytes:
        serialize_bytes = self.serialize()
        return hashlib.sha256(serialize_bytes).digest()

    def to_dict(self) -> Dict:
        return asdict(self.config)

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheKeyData:
        return self.config


class LLMCacheValue(CacheValue[LLMCacheValueData]):
    def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        return self.value.to_dict()

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheValueData:
        return self.value

    def __str__(self) -> str:
        return f"value: {str(self.value)}"


class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    def __init__(self, cache_manager: CacheManager) -> None:
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> Optional[LLMCacheValue]:
        return await self._cache_manager.get(key, LLMCacheValue, cache_config)

    async def set(
        self,
        key: LLMCacheKey,
        value: LLMCacheValue,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return await self.get(key, cache_config) is not None

    def new_key(self, **kwargs) -> LLMCacheKey:
        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)

    def new_value(self, **kwargs) -> LLMCacheValue:
        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
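A rough usage sketch for LLMCacheClient; the ModelOutput fields used here (text, error_code) are assumptions about dbgpt.core, and cache_manager is any CacheManager such as the one registered by initialize_cache:

async def demo(cache_manager):
    client = LLMCacheClient(cache_manager)
    key = client.new_key(prompt="What is DB-GPT?", model_name="vicuna-13b-v1.5")
    if not await client.exists(key):
        # new_value() routes through LLMCacheValueData.from_dict, so a plain
        # dict is converted into a ModelOutput (fields assumed).
        value = client.new_value(output={"text": "An AI-native data app framework.", "error_code": 0})
        await client.set(key, value)
    cached = await client.get(key)  # Optional[LLMCacheValue]
    return cached.get_value().output if cached else None

# Run with: asyncio.run(demo(cache_manager))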
125 dbgpt/storage/cache/manager.py vendored Normal file
@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor
from dbgpt.storage.cache.storage.base import CacheStorage
from dbgpt.core.interface.cache import K, V
from dbgpt.core import (
    CacheKey,
    CacheValue,
    CacheConfig,
    Serializer,
    Serializable,
)
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async

logger = logging.getLogger(__name__)


class CacheManager(BaseComponent, ABC):
    name = ComponentType.MODEL_CACHE_MANAGER

    def __init__(self, system_app: Optional[SystemApp] = None):
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        """Set cache"""

    @abstractmethod
    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[CacheValue[V]]:
        """Get cache with key"""

    @property
    @abstractmethod
    def serializer(self) -> Serializer:
        """Get cache serializer"""


class LocalCacheManager(CacheManager):
    def __init__(
        self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
    ) -> None:
        super().__init__(system_app)
        self._serializer = serializer
        self._storage = storage

    @property
    def executor(self) -> Executor:
        """Return executor to submit task"""
        self._executor = self.system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        return self._executor

    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        if self._storage.support_async():
            await self._storage.aset(key, value, cache_config)
        else:
            await blocking_func_to_async(
                self.executor, self._storage.set, key, value, cache_config
            )

    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[CacheValue[V]]:
        if self._storage.support_async():
            item_bytes = await self._storage.aget(key, cache_config)
        else:
            item_bytes = await blocking_func_to_async(
                self.executor, self._storage.get, key, cache_config
            )
        if not item_bytes:
            return None
        return self._serializer.deserialize(item_bytes.value_data, cls)

    @property
    def serializer(self) -> Serializer:
        return self._serializer


def initialize_cache(
    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
    from dbgpt.util.serialization.json_serialization import JsonSerializer
    from dbgpt.storage.cache.storage.base import MemoryCacheStorage

    cache_storage = None
    if storage_type == "disk":
        try:
            from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage

            cache_storage = DiskCacheStorage(
                persist_dir, mem_table_buffer_mb=max_memory_mb
            )
        except ImportError as e:
            logger.warning(
                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
            )
            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    else:
        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    system_app.register(
        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
    )
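The disk backend is optional: initialize_cache falls back to MemoryCacheStorage when rocksdict is not installed. A hedged registration sketch, continuing from a SystemApp instance as in the __init__.py example above (the path is a placeholder):

initialize_cache(
    system_app,
    storage_type="disk",
    max_memory_mb=256,
    persist_dir="/tmp/dbgpt_model_cache",  # placeholder; any writable directory
)

Note that LocalCacheManager bridges synchronous storages onto the event loop with blocking_func_to_async, so DiskCacheStorage calls never block the loop.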
0 dbgpt/storage/cache/protocal/__init__.py vendored Normal file
0 dbgpt/storage/cache/storage/__init__.py vendored Normal file
251 dbgpt/storage/cache/storage/base.py vendored Normal file
@@ -0,0 +1,251 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import msgpack
import logging

from dbgpt.core.interface.cache import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from dbgpt.util.memory_utils import _get_object_bytes

logger = logging.getLogger(__name__)


@dataclass
class StorageItem:
    """
    A class representing a storage item.

    This class encapsulates data related to a storage item, such as its length,
    the hash of the key, and the data for both the key and value.

    Attributes:
        length (int): The byte length of the storage item.
        key_hash (bytes): The hash value of the storage item's key.
        key_data (bytes): The data of the storage item's key, represented in bytes.
        value_data (bytes): The data of the storage item's value, also in bytes.
    """

    length: int  # The byte length of the storage item
    key_hash: bytes  # The hash value of the storage item's key
    key_data: bytes  # The data of the storage item's key
    value_data: bytes  # The data of the storage item's value

    @staticmethod
    def build_from(
        key_hash: bytes, key_data: bytes, value_data: bytes
    ) -> "StorageItem":
        length = (
            32
            + _get_object_bytes(key_hash)
            + _get_object_bytes(key_data)
            + _get_object_bytes(value_data)
        )
        return StorageItem(
            length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
        )

    @staticmethod
    def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
        key_hash = key.get_hash_bytes()
        key_data = key.serialize()
        value_data = value.serialize()
        return StorageItem.build_from(key_hash, key_data, value_data)

    def serialize(self) -> bytes:
        """Serialize the StorageItem into a byte stream using MessagePack.

        This method packs the object data into a dictionary, marking the
        key_data and value_data fields as raw binary data to avoid re-serialization.

        Returns:
            bytes: The serialized bytes.
        """
        obj = {
            "length": self.length,
            "key_hash": msgpack.ExtType(1, self.key_hash),
            "key_data": msgpack.ExtType(2, self.key_data),
            "value_data": msgpack.ExtType(3, self.value_data),
        }
        return msgpack.packb(obj)

    @staticmethod
    def deserialize(data: bytes) -> "StorageItem":
        """Deserialize bytes back into a StorageItem using MessagePack.

        This extracts the fields from the MessagePack dict back into
        a StorageItem object.

        Args:
            data (bytes): Serialized bytes

        Returns:
            StorageItem: Deserialized StorageItem object.
        """
        obj = msgpack.unpackb(data)
        key_hash = obj["key_hash"].data
        key_data = obj["key_data"].data
        value_data = obj["value_data"].data

        return StorageItem(
            length=obj["length"],
            key_hash=key_hash,
            key_data=key_data,
            value_data=value_data,
        )


class CacheStorage(ABC):
    @abstractmethod
    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        """Check whether the CacheConfig is legal.

        Args:
            cache_config (Optional[CacheConfig]): Cache config.
            raise_error (Optional[bool]): Whether to raise an error if the config is illegal.

        Returns:
            bool: True if the config is legal.

        Raises:
            ValueError: If raise_error is True and the config is illegal.
        """

    def support_async(self) -> bool:
        return False

    @abstractmethod
    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to get cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved according to the key.
                If the cache key does not exist, return None.
        """

    async def aget(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to get cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved according to the key.
                If the cache key does not exist, return None.
        """
        raise NotImplementedError

    @abstractmethod
    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set to cache
            value (CacheValue[V]): The value to set to cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    async def aset(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to set to cache
            value (CacheValue[V]): The value to set to cache
            cache_config (Optional[CacheConfig]): Cache config
        """
        raise NotImplementedError


class MemoryCacheStorage(CacheStorage):
    def __init__(self, max_memory_mb: int = 256):
        self.cache = OrderedDict()
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory_usage = 0

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)
        # Exact match retrieval
        key_hash = hash(key)
        item: StorageItem = self.cache.get(key_hash)
        logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")

        if not item:
            return None
        # Move the item to the end of the OrderedDict to signify recent use.
        self.cache.move_to_end(key_hash)
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        key_hash = hash(key)
        item = StorageItem.build_from_kv(key, value)
        # Calculate the memory size of the new entry
        new_entry_size = _get_object_bytes(item)
        # Evict entries if necessary; stop once the cache is empty.
        while self.cache and self.current_memory_usage + new_entry_size > self.max_memory:
            self._apply_cache_policy(cache_config)

        # Store the item in the cache.
        self.cache[key_hash] = item
        self.current_memory_usage += new_entry_size
        logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")

    def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return self.get(key, cache_config) is not None

    def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
        # Remove the oldest/newest item based on the cache policy, and release
        # the memory it was charged for so the eviction loop can terminate.
        if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
            _, evicted = self.cache.popitem(last=False)
        else:  # Default is LRU
            _, evicted = self.cache.popitem(last=True)
        self.current_memory_usage -= _get_object_bytes(evicted)
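The ExtType tags (1, 2, 3) mark key_hash, key_data and value_data as raw binary, so msgpack carries the already-serialized bytes without a second encoding pass. A quick round-trip sketch:

from dbgpt.storage.cache.storage.base import StorageItem

item = StorageItem.build_from(b"hash", b"key", b"value")
data = item.serialize()              # msgpack-packed bytes
restored = StorageItem.deserialize(data)
assert restored == item              # dataclass field-wise equality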
0 dbgpt/storage/cache/storage/disk/__init__.py vendored Normal file
91 dbgpt/storage/cache/storage/disk/disk_storage.py vendored Normal file
@@ -0,0 +1,91 @@
from typing import Optional
import logging
from rocksdict import Rdict, Options

from dbgpt.core.interface.cache import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
)
from dbgpt.storage.cache.storage.base import StorageItem, CacheStorage

logger = logging.getLogger(__name__)


def db_options(
    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
):
    opt = Options()
    # Create the database if it does not exist
    opt.create_if_missing(True)
    # Configure the number of background jobs (default 2)
    opt.set_max_background_jobs(background_threads)
    # Configure the mem-table to a large value
    opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
    # opt.set_write_buffer_size(1024)
    # opt.set_level_zero_file_num_compaction_trigger(4)
    # Configure L0 and L1 size, let them have the same size (1 GB)
    # opt.set_max_bytes_for_level_base(0x40000000)
    # 256 MB file size
    # opt.set_target_file_size_base(0x10000000)
    # Use a smaller compaction multiplier
    # opt.set_max_bytes_for_level_multiplier(4.0)
    # Use 8-byte prefix (2 ^ 64 is far enough for transaction counts)
    # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
    # Set to plain-table
    # opt.set_plain_table_factory(PlainTableFactoryOptions())
    return opt


class DiskCacheStorage(CacheStorage):
    def __init__(
        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
    ) -> None:
        super().__init__()
        self.db: Rdict = Rdict(
            persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
        )

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)

        # Exact match retrieval
        key_hash = key.get_hash_bytes()
        item_bytes = self.db.get(key_hash)
        if not item_bytes:
            return None
        item = StorageItem.deserialize(item_bytes)
        logger.debug(f"Read file cache, key: {key}, storage item: {item}")
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        item = StorageItem.build_from_kv(key, value)
        key_hash = item.key_hash
        self.db[key_hash] = item.serialize()
        logger.debug(f"Save file cache, key: {key}, value: {value}")
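A minimal sketch of exercising DiskCacheStorage directly (requires the optional rocksdict package; the stand-in key/value classes mirror the mocks in the tests below, and the path is a placeholder):

from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage

class _Key:  # stand-in exposing the two methods DiskCacheStorage needs
    def get_hash_bytes(self):
        return b"key_hash"

    def serialize(self):
        return b"key_data"

class _Value:  # stand-in for a CacheValue
    def serialize(self):
        return b"value_data"

storage = DiskCacheStorage("/tmp/dbgpt_disk_cache", mem_table_buffer_mb=64)
storage.set(_Key(), _Value())
item = storage.get(_Key())  # StorageItem carrying the stored bytes
assert item is not None and item.value_data == b"value_data"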
0 dbgpt/storage/cache/storage/tests/__init__.py vendored Normal file
53 dbgpt/storage/cache/storage/tests/test_storage.py vendored Normal file
@@ -0,0 +1,53 @@
import pytest

from ..base import StorageItem
from dbgpt.util.memory_utils import _get_object_bytes


def test_build_from():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    assert item.key_hash == key_hash
    assert item.key_data == key_data
    assert item.value_data == value_data
    assert item.length == 32 + _get_object_bytes(key_hash) + _get_object_bytes(
        key_data
    ) + _get_object_bytes(value_data)


def test_build_from_kv():
    class MockCacheKey:
        def get_hash_bytes(self):
            return b"key_hash"

        def serialize(self):
            return b"key_data"

    class MockCacheValue:
        def serialize(self):
            return b"value_data"

    key = MockCacheKey()
    value = MockCacheValue()
    item = StorageItem.build_from_kv(key, value)

    assert item.key_hash == key.get_hash_bytes()
    assert item.key_data == key.serialize()
    assert item.value_data == value.serialize()


def test_serialize_deserialize():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)

    serialized = item.serialize()
    deserialized = StorageItem.deserialize(serialized)

    assert deserialized.key_hash == item.key_hash
    assert deserialized.key_data == item.key_data
    assert deserialized.value_data == item.value_data
    assert deserialized.length == item.length
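These tests cover only the byte-level packing of StorageItem; they can be run from the repository root with, for example: pytest dbgpt/storage/cache/storage/tests/test_storage.py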