refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
FangYin Cheng authored on 2023-12-08 14:45:59 +08:00, committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

12 dbgpt/storage/cache/__init__.py vendored Normal file

@@ -0,0 +1,12 @@
from dbgpt.storage.cache.manager import CacheManager, initialize_cache
from dbgpt.storage.cache.storage.base import MemoryCacheStorage
from dbgpt.storage.cache.llm_cache import LLMCacheKey, LLMCacheValue, LLMCacheClient

__all__ = [
"LLMCacheKey",
"LLMCacheValue",
"LLMCacheClient",
"CacheManager",
"initialize_cache",
"MemoryCacheStorage",
]
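
These re-exports are presumably the intended import surface for SDK consumers; a minimal import sketch (not part of the commit):

# Minimal consumer-side sketch: use the package re-exports above.
from dbgpt.storage.cache import (
    CacheManager,
    LLMCacheClient,
    LLMCacheKey,
    LLMCacheValue,
    MemoryCacheStorage,
    initialize_cache,
)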

152 dbgpt/storage/cache/llm_cache.py vendored Normal file

@@ -0,0 +1,152 @@
from typing import Optional, Dict, Any, Union, List
from dataclasses import dataclass, asdict
import hashlib

from dbgpt.core.interface.cache import (
    CacheKey,
    CacheValue,
    CacheClient,
    CacheConfig,
)
from dbgpt.storage.cache.manager import CacheManager
from dbgpt.core import ModelOutput, Serializer
from dbgpt.model.base import ModelType


@dataclass
class LLMCacheKeyData:
    prompt: str
    model_name: str
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    model_type: Optional[str] = ModelType.HF


CacheOutputType = Union[ModelOutput, List[ModelOutput]]


@dataclass
class LLMCacheValueData:
    output: CacheOutputType
    user: Optional[str] = None
    _is_list: Optional[bool] = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't create LLMCacheValueData object: output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            output_list = []
            for out in output:
                if isinstance(out, dict):
                    out = ModelOutput(**out)
                output_list.append(out)
            output = output_list
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        output = self.output
        is_list = False
        if isinstance(output, list):
            output_list = []
            is_list = True
            for out in output:
                output_list.append(out.to_dict())
            output = output_list
        else:
            output = output.to_dict()
        return {"output": output, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        return self._is_list

    def __str__(self) -> str:
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output (last two items): {self.output[-2:]}"


class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        serialize_bytes = self.serialize()
        return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, LLMCacheKey):
            return False
        return self.config == other.config

    def get_hash_bytes(self) -> bytes:
        serialize_bytes = self.serialize()
        return hashlib.sha256(serialize_bytes).digest()

    def to_dict(self) -> Dict:
        return asdict(self.config)

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheKeyData:
        return self.config


class LLMCacheValue(CacheValue[LLMCacheValueData]):
    def __init__(self, serializer: Optional[Serializer] = None, **kwargs) -> None:
        super().__init__()
        self._serializer = serializer
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        return self.value.to_dict()

    def serialize(self) -> bytes:
        return self._serializer.serialize(self)

    def get_value(self) -> LLMCacheValueData:
        return self.value

    def __str__(self) -> str:
        return f"value: {str(self.value)}"


class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    def __init__(self, cache_manager: CacheManager) -> None:
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> Optional[LLMCacheValue]:
        return await self._cache_manager.get(key, LLMCacheValue, cache_config)

    async def set(
        self,
        key: LLMCacheKey,
        value: LLMCacheValue,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return await self.get(key, cache_config) is not None

    def new_key(self, **kwargs) -> LLMCacheKey:
        return LLMCacheKey(serializer=self._cache_manager.serializer, **kwargs)

    def new_value(self, **kwargs) -> LLMCacheValue:
        return LLMCacheValue(serializer=self._cache_manager.serializer, **kwargs)
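
A usage sketch for the client above (not part of the commit): it assumes an already-initialized CacheManager (see manager.py below); the prompt, model name, and output fields are illustrative, with the output dict following the ModelOutput fields.

# Usage sketch (illustrative, not part of the commit).
from dbgpt.storage.cache.llm_cache import LLMCacheClient
from dbgpt.storage.cache.manager import CacheManager


async def demo(cache_manager: CacheManager):
    client = LLMCacheClient(cache_manager)
    # new_key/new_value attach the manager's serializer automatically.
    key = client.new_key(prompt="What is DB-GPT?", model_name="vicuna-13b-v1.5")
    if not await client.exists(key):
        value = client.new_value(output={"text": "An AI-native data app framework.", "error_code": 0})
        await client.set(key, value)
    return await client.get(key)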

125 dbgpt/storage/cache/manager.py vendored Normal file

@@ -0,0 +1,125 @@
from abc import ABC, abstractmethod
from typing import Optional, Type
import logging
from concurrent.futures import Executor

from dbgpt.storage.cache.storage.base import CacheStorage
from dbgpt.core.interface.cache import K, V
from dbgpt.core import (
    CacheKey,
    CacheValue,
    CacheConfig,
    Serializer,
    Serializable,
)
from dbgpt.component import BaseComponent, ComponentType, SystemApp
from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async

logger = logging.getLogger(__name__)


class CacheManager(BaseComponent, ABC):
    name = ComponentType.MODEL_CACHE_MANAGER

    def __init__(self, system_app: SystemApp | None = None):
        super().__init__(system_app)

    def init_app(self, system_app: SystemApp):
        self.system_app = system_app

    @abstractmethod
    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        """Set cache"""

    @abstractmethod
    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[CacheValue[V]]:
        """Get cache with key"""

    @property
    @abstractmethod
    def serializer(self) -> Serializer:
        """Get cache serializer"""


class LocalCacheManager(CacheManager):
    def __init__(
        self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
    ) -> None:
        super().__init__(system_app)
        self._serializer = serializer
        self._storage = storage

    @property
    def executor(self) -> Executor:
        """Return executor to submit task"""
        self._executor = self.system_app.get_component(
            ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
        ).create()
        return self._executor

    async def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ):
        if self._storage.support_async():
            await self._storage.aset(key, value, cache_config)
        else:
            await blocking_func_to_async(
                self.executor, self._storage.set, key, value, cache_config
            )

    async def get(
        self,
        key: CacheKey[K],
        cls: Type[Serializable],
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[CacheValue[V]]:
        if self._storage.support_async():
            item_bytes = await self._storage.aget(key, cache_config)
        else:
            item_bytes = await blocking_func_to_async(
                self.executor, self._storage.get, key, cache_config
            )
        if not item_bytes:
            return None
        return self._serializer.deserialize(item_bytes.value_data, cls)

    @property
    def serializer(self) -> Serializer:
        return self._serializer


def initialize_cache(
    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
    from dbgpt.util.serialization.json_serialization import JsonSerializer
    from dbgpt.storage.cache.storage.base import MemoryCacheStorage

    cache_storage = None
    if storage_type == "disk":
        try:
            from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage

            cache_storage = DiskCacheStorage(
                persist_dir, mem_table_buffer_mb=max_memory_mb
            )
        except ImportError as e:
            logger.warning(
                "Can't import DiskCacheStorage, falling back to MemoryCacheStorage, "
                f"import error message: {str(e)}"
            )
            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    else:
        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    system_app.register(
        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
    )
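
A wiring sketch (not part of the commit): register the cache at startup and fetch it back as a component. SystemApp construction is simplified here; component retrieval mirrors the executor property above, and the persist_dir path is illustrative.

# Startup wiring sketch (simplified; not part of the commit).
from dbgpt.component import ComponentType, SystemApp
from dbgpt.storage.cache.manager import CacheManager, initialize_cache

system_app = SystemApp()
initialize_cache(
    system_app,
    storage_type="memory",        # "disk" uses rocksdict, falling back to memory on ImportError
    max_memory_mb=256,
    persist_dir="./model_cache",  # only used by the disk backend
)
cache_manager: CacheManager = system_app.get_component(
    ComponentType.MODEL_CACHE_MANAGER, CacheManager
)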

251 dbgpt/storage/cache/storage/base.py vendored Normal file

@@ -0,0 +1,251 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from collections import OrderedDict
import logging

import msgpack

from dbgpt.core.interface.cache import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
    CachePolicy,
)
from dbgpt.util.memory_utils import _get_object_bytes

logger = logging.getLogger(__name__)

@dataclass
class StorageItem:
    """A class representing a storage item.

    This class encapsulates data related to a storage item, such as its length,
    the hash of the key, and the data for both the key and value.

    Attributes:
        length (int): The length of the storage item in bytes.
        key_hash (bytes): The hash value of the storage item's key.
        key_data (bytes): The data of the storage item's key, represented in bytes.
        value_data (bytes): The data of the storage item's value, also in bytes.
    """

    length: int  # The length of the storage item in bytes
    key_hash: bytes  # The hash value of the storage item's key
    key_data: bytes  # The data of the storage item's key
    value_data: bytes  # The data of the storage item's value

    @staticmethod
    def build_from(
        key_hash: bytes, key_data: bytes, value_data: bytes
    ) -> "StorageItem":
        # 32 bytes of fixed overhead plus the measured size of each field
        length = (
            32
            + _get_object_bytes(key_hash)
            + _get_object_bytes(key_data)
            + _get_object_bytes(value_data)
        )
        return StorageItem(
            length=length, key_hash=key_hash, key_data=key_data, value_data=value_data
        )

    @staticmethod
    def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
        key_hash = key.get_hash_bytes()
        key_data = key.serialize()
        value_data = value.serialize()
        return StorageItem.build_from(key_hash, key_data, value_data)

    def serialize(self) -> bytes:
        """Serialize the StorageItem into a byte stream using MessagePack.

        This method packs the object data into a dictionary, marking the
        key_data and value_data fields as raw binary data to avoid re-serialization.

        Returns:
            bytes: The serialized bytes.
        """
        obj = {
            "length": self.length,
            "key_hash": msgpack.ExtType(1, self.key_hash),
            "key_data": msgpack.ExtType(2, self.key_data),
            "value_data": msgpack.ExtType(3, self.value_data),
        }
        return msgpack.packb(obj)

    @staticmethod
    def deserialize(data: bytes) -> "StorageItem":
        """Deserialize bytes back into a StorageItem using MessagePack.

        This extracts the fields from the MessagePack dict back into
        a StorageItem object.

        Args:
            data (bytes): Serialized bytes

        Returns:
            StorageItem: Deserialized StorageItem object.
        """
        obj = msgpack.unpackb(data)
        key_hash = obj["key_hash"].data
        key_data = obj["key_data"].data
        value_data = obj["value_data"].data
        return StorageItem(
            length=obj["length"],
            key_hash=key_hash,
            key_data=key_data,
            value_data=value_data,
        )

class CacheStorage(ABC):
    @abstractmethod
    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        """Check whether the CacheConfig is legal.

        Args:
            cache_config (Optional[CacheConfig]): Cache config.
            raise_error (Optional[bool]): Whether to raise an error if the config is illegal.

        Returns:
            bool: True if the config is legal, otherwise False.

        Raises:
            ValueError: If raise_error is True and the config is illegal.
        """

    def support_async(self) -> bool:
        """Whether this storage supports the async get/set methods."""
        return False

    @abstractmethod
    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key.

        Args:
            key (CacheKey[K]): The key to get cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved according to key. If the
                cache key does not exist, return None.
        """

    async def aget(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        """Retrieve a storage item from the cache using the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to get cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[StorageItem]: The storage item retrieved according to key. If the
                cache key does not exist, return None.
        """
        raise NotImplementedError

    @abstractmethod
    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key.

        Args:
            key (CacheKey[K]): The key to set to cache
            value (CacheValue[V]): The value to set to cache
            cache_config (Optional[CacheConfig]): Cache config
        """

    async def aset(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key asynchronously.

        Args:
            key (CacheKey[K]): The key to set to cache
            value (CacheValue[V]): The value to set to cache
            cache_config (Optional[CacheConfig]): Cache config
        """
        raise NotImplementedError

class MemoryCacheStorage(CacheStorage):
    def __init__(self, max_memory_mb: int = 256):
        self.cache = OrderedDict()
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory_usage = 0

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "MemoryCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)
        # Exact match retrieval
        key_hash = hash(key)
        item: StorageItem = self.cache.get(key_hash)
        logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
        if not item:
            return None
        # Move the item to the end of the OrderedDict to signify recent use.
        self.cache.move_to_end(key_hash)
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        key_hash = hash(key)
        item = StorageItem.build_from_kv(key, value)
        # Calculate the memory size of the new entry
        new_entry_size = _get_object_bytes(item)
        # Evict entries until the new entry fits within the memory budget
        while (
            self.cache
            and self.current_memory_usage + new_entry_size > self.max_memory
        ):
            self._apply_cache_policy(cache_config)
        # Store the item in the cache.
        self.cache[key_hash] = item
        self.current_memory_usage += new_entry_size
        logger.debug(f"MemoryCacheStorage set key {key}, hash {key_hash}, item: {item}")

    def exists(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> bool:
        return self.get(key, cache_config) is not None

    def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
        # Evict one entry according to the cache policy and reclaim its memory;
        # without the reclamation step, the eviction loop in set() never ends.
        if cache_config and cache_config.cache_policy == CachePolicy.FIFO:
            # FIFO: evict the entry at the front (earliest inserted).
            _, evicted_item = self.cache.popitem(last=False)
        else:
            # Default is LRU: get() moves accessed entries to the end, so the
            # least recently used entry sits at the front of the OrderedDict.
            _, evicted_item = self.cache.popitem(last=False)
        self.current_memory_usage -= _get_object_bytes(evicted_item)
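
An eviction sketch for MemoryCacheStorage (not part of the commit): the mock key/value classes are hypothetical, mirroring the test doubles at the end of this diff, and entry sizes depend on _get_object_bytes, so treat the numbers as illustrative.

# Eviction sketch with hypothetical mocks (not part of the commit).
from dbgpt.storage.cache.storage.base import MemoryCacheStorage


class MockKey:
    def __init__(self, n: int):
        self.n = n

    def __hash__(self) -> int:
        return self.n

    def get_hash_bytes(self) -> bytes:
        return self.n.to_bytes(8, "big")

    def serialize(self) -> bytes:
        return f"key-{self.n}".encode()


class MockValue:
    def serialize(self) -> bytes:
        return b"x" * 1024


storage = MemoryCacheStorage(max_memory_mb=1)
for i in range(10_000):
    storage.set(MockKey(i), MockValue())  # older entries are evicted (LRU by default)
assert storage.current_memory_usage <= storage.max_memory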

91 dbgpt/storage/cache/storage/disk/disk_storage.py vendored Normal file
@@ -0,0 +1,91 @@
from typing import Optional
import logging

from rocksdict import Rdict, Options

from dbgpt.core.interface.cache import (
    K,
    V,
    CacheKey,
    CacheValue,
    CacheConfig,
    RetrievalPolicy,
)
from dbgpt.storage.cache.storage.base import StorageItem, CacheStorage

logger = logging.getLogger(__name__)


def db_options(
    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
) -> Options:
    opt = Options()
    # Create the database if it is missing
    opt.create_if_missing(True)
    # Allow more background jobs (default 2)
    opt.set_max_background_jobs(background_threads)
    # Configure the mem-table to a large value
    opt.set_write_buffer_size(mem_table_buffer_mb * 1024 * 1024)
    # opt.set_write_buffer_size(1024)
    # opt.set_level_zero_file_num_compaction_trigger(4)
    # Configure L0 and L1 to the same size (1 GB)
    # opt.set_max_bytes_for_level_base(0x40000000)
    # 256 MB file size
    # opt.set_target_file_size_base(0x10000000)
    # Use a smaller compaction multiplier
    # opt.set_max_bytes_for_level_multiplier(4.0)
    # Use an 8-byte prefix (2^64 is far enough for transaction counts)
    # opt.set_prefix_extractor(SliceTransform.create_max_len_prefix(8))
    # Set to plain-table format
    # opt.set_plain_table_factory(PlainTableFactoryOptions())
    return opt


class DiskCacheStorage(CacheStorage):
    def __init__(
        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
    ) -> None:
        super().__init__()
        self.db: Rdict = Rdict(
            persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)
        )

    def check_config(
        self,
        cache_config: Optional[CacheConfig] = None,
        raise_error: Optional[bool] = True,
    ) -> bool:
        if (
            cache_config
            and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH
        ):
            if raise_error:
                raise ValueError(
                    "DiskCacheStorage only supports 'EXACT_MATCH' retrieval policy"
                )
            return False
        return True

    def get(
        self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
    ) -> Optional[StorageItem]:
        self.check_config(cache_config, raise_error=True)
        # Exact match retrieval
        key_hash = key.get_hash_bytes()
        item_bytes = self.db.get(key_hash)
        if not item_bytes:
            return None
        item = StorageItem.deserialize(item_bytes)
        logger.debug(f"Read file cache, key: {key}, storage item: {item}")
        return item

    def set(
        self,
        key: CacheKey[K],
        value: CacheValue[V],
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        item = StorageItem.build_from_kv(key, value)
        key_hash = item.key_hash
        self.db[key_hash] = item.serialize()
        logger.debug(f"Save file cache, key: {key}, value: {value}")

53 dbgpt/storage/cache/storage/tests/… vendored Normal file
@@ -0,0 +1,53 @@
import pytest

from ..base import StorageItem
from dbgpt.util.memory_utils import _get_object_bytes


def test_build_from():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)
    assert item.key_hash == key_hash
    assert item.key_data == key_data
    assert item.value_data == value_data
    assert item.length == (
        32
        + _get_object_bytes(key_hash)
        + _get_object_bytes(key_data)
        + _get_object_bytes(value_data)
    )


def test_build_from_kv():
    class MockCacheKey:
        def get_hash_bytes(self):
            return b"key_hash"

        def serialize(self):
            return b"key_data"

    class MockCacheValue:
        def serialize(self):
            return b"value_data"

    key = MockCacheKey()
    value = MockCacheValue()
    item = StorageItem.build_from_kv(key, value)
    assert item.key_hash == key.get_hash_bytes()
    assert item.key_data == key.serialize()
    assert item.value_data == value.serialize()


def test_serialize_deserialize():
    key_hash = b"key_hash"
    key_data = b"key_data"
    value_data = b"value_data"
    item = StorageItem.build_from(key_hash, key_data, value_data)
    serialized = item.serialize()
    deserialized = StorageItem.deserialize(serialized)
    assert deserialized.key_hash == item.key_hash
    assert deserialized.key_data == item.key_data
    assert deserialized.value_data == item.value_data
    assert deserialized.length == item.length