mirror of https://github.com/csunny/DB-GPT.git

Commit: chore: Add pylint for storage (#1298)
dbgpt/storage/cache/__init__.py (7 lines changed)

@@ -1,6 +1,7 @@
-from dbgpt.storage.cache.llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
-from dbgpt.storage.cache.manager import CacheManager, initialize_cache
-from dbgpt.storage.cache.storage.base import MemoryCacheStorage
+"""Module for cache storage."""
+from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue  # noqa: F401
+from .manager import CacheManager, initialize_cache  # noqa: F401
+from .storage.base import MemoryCacheStorage  # noqa: F401
 
 __all__ = [
     "LLMCacheKey",
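This re-export pattern is what satisfies pylint here: the relative imports are marked `# noqa: F401` because they exist only to be re-exported, and `__all__` declares the public surface explicitly. A minimal sketch of the same pattern, using a hypothetical package `mypkg`:

"""Hypothetical mypkg/__init__.py demonstrating the re-export pattern."""
# Relative import; F401 (imported but unused) is suppressed because the
# name is intentionally re-exported from the package root.
from .core import Engine  # noqa: F401

# __all__ makes the public surface explicit, so linters and
# `from mypkg import *` agree on what is exported.
__all__ = ["Engine"]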
dbgpt/storage/cache/embedding_cache.py (1 line changed)

@@ -0,0 +1 @@
+"""Embeddings cache."""
dbgpt/storage/cache/llm_cache.py (83 lines changed)

@@ -1,21 +1,26 @@
+"""Cache client for LLM."""
+
 import hashlib
 from dataclasses import asdict, dataclass
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 
-from dbgpt.core import ModelOutput, Serializer
+from dbgpt.core import ModelOutput
 from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
-from dbgpt.model.base import ModelType
-from dbgpt.storage.cache.manager import CacheManager
+
+from .manager import CacheManager
 
 
 @dataclass
 class LLMCacheKeyData:
+    """Cache key data for LLM."""
+
     prompt: str
     model_name: str
     temperature: Optional[float] = 0.7
     max_new_tokens: Optional[int] = None
     top_p: Optional[float] = 1.0
-    model_type: Optional[str] = ModelType.HF
+    # See dbgpt.model.base.ModelType
+    model_type: Optional[str] = "huggingface"
 
 
 CacheOutputType = Union[ModelOutput, List[ModelOutput]]

@@ -23,12 +28,15 @@ CacheOutputType = Union[ModelOutput, List[ModelOutput]]
 
 @dataclass
 class LLMCacheValueData:
+    """Cache value data for LLM."""
+
     output: CacheOutputType
     user: Optional[str] = None
-    _is_list: Optional[bool] = False
+    _is_list: bool = False
 
     @staticmethod
     def from_dict(**kwargs) -> "LLMCacheValueData":
+        """Create LLMCacheValueData object from dict."""
         output = kwargs.get("output")
         if not output:
             raise ValueError("Can't new LLMCacheValueData object, output is None")

@@ -46,6 +54,7 @@ class LLMCacheValueData:
         return LLMCacheValueData(**kwargs)
 
     def to_dict(self) -> Dict:
+        """Convert to dict."""
        output = self.output
         is_list = False
         if isinstance(output, list):

@@ -53,16 +62,18 @@ class LLMCacheValueData:
             is_list = True
             for out in output:
                 output_list.append(out.to_dict())
-            output = output_list
+            output = output_list  # type: ignore
         else:
-            output = output.to_dict()
+            output = output.to_dict()  # type: ignore
         return {"output": output, "_is_list": is_list, "user": self.user}
 
     @property
     def is_list(self) -> bool:
+        """Return whether the output is a list."""
         return self._is_list
 
     def __str__(self) -> str:
+        """Return string representation."""
         if not isinstance(self.output, list):
             return f"user: {self.user}, output: {self.output}"
         else:
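`LLMCacheValueData.to_dict` flattens either a single output or a batch into one serializable shape, recording which case applied in `_is_list` so deserialization can reverse it. A self-contained toy round-trip of that convention (the `Output` class below is a stand-in, not `ModelOutput`):

from dataclasses import dataclass
from typing import Dict, List, Union


@dataclass
class Output:
    """Stand-in for ModelOutput with a dict round-trip."""

    text: str

    def to_dict(self) -> Dict:
        return {"text": self.text}


def value_to_dict(output: Union[Output, List[Output]]) -> Dict:
    # Flatten single outputs and batches into one shape, recording which
    # case it was in "_is_list" so deserialization can undo it.
    if isinstance(output, list):
        return {"output": [o.to_dict() for o in output], "_is_list": True}
    return {"output": output.to_dict(), "_is_list": False}


def value_from_dict(data: Dict) -> Union[Output, List[Output]]:
    if data["_is_list"]:
        return [Output(**d) for d in data["output"]]
    return Output(**data["output"])


assert value_from_dict(value_to_dict(Output("hi"))) == Output("hi")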
@@ -70,74 +81,116 @@ class LLMCacheValueData:
 
 
 class LLMCacheKey(CacheKey[LLMCacheKeyData]):
+    """Cache key for LLM."""
+
     def __init__(self, **kwargs) -> None:
+        """Create a new instance of LLMCacheKey."""
         super().__init__()
         self.config = LLMCacheKeyData(**kwargs)
 
     def __hash__(self) -> int:
+        """Return the hash value of the object."""
         serialize_bytes = self.serialize()
         return int(hashlib.sha256(serialize_bytes).hexdigest(), 16)
 
     def __eq__(self, other: Any) -> bool:
+        """Check equality with another key."""
         if not isinstance(other, LLMCacheKey):
             return False
         return self.config == other.config
 
     def get_hash_bytes(self) -> bytes:
+        """Return the byte array of hash value.
+
+        Returns:
+            bytes: The byte array of hash value.
+        """
         serialize_bytes = self.serialize()
         return hashlib.sha256(serialize_bytes).digest()
 
     def to_dict(self) -> Dict:
+        """Convert to dict."""
         return asdict(self.config)
 
     def get_value(self) -> LLMCacheKeyData:
+        """Return the real object of current cache key."""
         return self.config
 
 
 class LLMCacheValue(CacheValue[LLMCacheValueData]):
+    """Cache value for LLM."""
+
     def __init__(self, **kwargs) -> None:
+        """Create a new instance of LLMCacheValue."""
         super().__init__()
         self.value = LLMCacheValueData.from_dict(**kwargs)
 
     def to_dict(self) -> Dict:
+        """Convert to dict."""
         return self.value.to_dict()
 
     def get_value(self) -> LLMCacheValueData:
+        """Return the underlying real value."""
         return self.value
 
     def __str__(self) -> str:
+        """Return string representation."""
         return f"value: {str(self.value)}"
 
 
 class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
+    """Cache client for LLM."""
+
     def __init__(self, cache_manager: CacheManager) -> None:
+        """Create a new instance of LLMCacheClient."""
         super().__init__()
         self._cache_manager: CacheManager = cache_manager
 
     async def get(
-        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+        self,
+        key: LLMCacheKey,  # type: ignore
+        cache_config: Optional[CacheConfig] = None,
     ) -> Optional[LLMCacheValue]:
-        return await self._cache_manager.get(key, LLMCacheValue, cache_config)
+        """Retrieve a value from the cache using the provided key.
+
+        Args:
+            key (LLMCacheKey): The key to get cache
+            cache_config (Optional[CacheConfig]): Cache config
+
+        Returns:
+            Optional[LLMCacheValue]: The value retrieved according to key. If cache key
+                not exist, return None.
+        """
+        return cast(
+            LLMCacheValue,
+            await self._cache_manager.get(key, LLMCacheValue, cache_config),
+        )
 
     async def set(
         self,
-        key: LLMCacheKey,
-        value: LLMCacheValue,
+        key: LLMCacheKey,  # type: ignore
+        value: LLMCacheValue,  # type: ignore
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         return await self._cache_manager.set(key, value, cache_config)
 
     async def exists(
-        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
+        self,
+        key: LLMCacheKey,  # type: ignore
+        cache_config: Optional[CacheConfig] = None,
     ) -> bool:
+        """Check if a key exists in the cache."""
         return await self.get(key, cache_config) is not None
 
-    def new_key(self, **kwargs) -> LLMCacheKey:
+    def new_key(self, **kwargs) -> LLMCacheKey:  # type: ignore
+        """Create a cache key with params."""
         key = LLMCacheKey(**kwargs)
         key.set_serializer(self._cache_manager.serializer)
         return key
 
-    def new_value(self, **kwargs) -> LLMCacheValue:
+    def new_value(self, **kwargs) -> LLMCacheValue:  # type: ignore
+        """Create a cache value with params."""
         value = LLMCacheValue(**kwargs)
         value.set_serializer(self._cache_manager.serializer)
         return value
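`LLMCacheKey` is content-addressed: `__hash__` and `get_hash_bytes` both run SHA-256 over the serialized key, so identical prompt/model parameters always map to the same cache slot, even across processes. A minimal sketch of the idea, assuming JSON serialization with sorted keys (the repository's actual serializer is pluggable):

import hashlib
import json
from dataclasses import asdict, dataclass


@dataclass
class KeyData:
    """Stand-in for LLMCacheKeyData with a few representative fields."""

    prompt: str
    model_name: str
    temperature: float = 0.7


def key_hash_bytes(key: KeyData) -> bytes:
    # Sorting the keys makes the serialization (and thus the hash)
    # deterministic regardless of field insertion order.
    serialized = json.dumps(asdict(key), sort_keys=True).encode("utf-8")
    return hashlib.sha256(serialized).digest()


a = key_hash_bytes(KeyData("hello", "vicuna-13b"))
b = key_hash_bytes(KeyData("hello", "vicuna-13b"))
assert a == b  # same parameters -> same cache slot, across processes too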
dbgpt/storage/cache/manager.py (54 lines changed)

@@ -1,24 +1,31 @@
+"""Cache manager."""
+
 import logging
 from abc import ABC, abstractmethod
 from concurrent.futures import Executor
-from typing import Optional, Type
+from typing import Optional, Type, cast
 
 from dbgpt.component import BaseComponent, ComponentType, SystemApp
 from dbgpt.core import CacheConfig, CacheKey, CacheValue, Serializable, Serializer
 from dbgpt.core.interface.cache import K, V
-from dbgpt.storage.cache.storage.base import CacheStorage
 from dbgpt.util.executor_utils import ExecutorFactory, blocking_func_to_async
 
+from .storage.base import CacheStorage
+
 logger = logging.getLogger(__name__)
 
 
 class CacheManager(BaseComponent, ABC):
+    """The cache manager interface."""
+
     name = ComponentType.MODEL_CACHE_MANAGER
 
     def __init__(self, system_app: SystemApp | None = None):
+        """Create cache manager."""
         super().__init__(system_app)
 
     def init_app(self, system_app: SystemApp):
+        """Initialize cache manager."""
         self.system_app = system_app
 
     @abstractmethod

@@ -28,7 +35,7 @@ class CacheManager(BaseComponent, ABC):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ):
-        """Set cache"""
+        """Set cache with key."""
 
     @abstractmethod
     async def get(

@@ -36,27 +43,30 @@ class CacheManager(BaseComponent, ABC):
         key: CacheKey[K],
         cls: Type[Serializable],
         cache_config: Optional[CacheConfig] = None,
-    ) -> CacheValue[V]:
-        """Get cache with key"""
+    ) -> Optional[CacheValue[V]]:
+        """Retrieve cache with key."""
 
     @property
     @abstractmethod
     def serializer(self) -> Serializer:
-        """Get cache serializer"""
+        """Return serializer to serialize/deserialize cache value."""
 
 
 class LocalCacheManager(CacheManager):
+    """Local cache manager."""
+
     def __init__(
         self, system_app: SystemApp, serializer: Serializer, storage: CacheStorage
     ) -> None:
+        """Create local cache manager."""
         super().__init__(system_app)
         self._serializer = serializer
         self._storage = storage
 
     @property
     def executor(self) -> Executor:
-        """Return executor to submit task"""
-        self._executor = self.system_app.get_component(
+        """Return executor."""
+        return self.system_app.get_component(  # type: ignore
             ComponentType.EXECUTOR_DEFAULT, ExecutorFactory
         ).create()

@@ -66,6 +76,7 @@ class LocalCacheManager(CacheManager):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ):
+        """Set cache with key."""
         if self._storage.support_async():
             await self._storage.aset(key, value, cache_config)
         else:

@@ -78,7 +89,8 @@ class LocalCacheManager(CacheManager):
         key: CacheKey[K],
         cls: Type[Serializable],
         cache_config: Optional[CacheConfig] = None,
-    ) -> CacheValue[V]:
+    ) -> Optional[CacheValue[V]]:
+        """Retrieve cache with key."""
         if self._storage.support_async():
             item_bytes = await self._storage.aget(key, cache_config)
         else:

@@ -87,30 +99,42 @@ class LocalCacheManager(CacheManager):
         )
         if not item_bytes:
             return None
-        return self._serializer.deserialize(item_bytes.value_data, cls)
+        return cast(
+            CacheValue[V], self._serializer.deserialize(item_bytes.value_data, cls)
+        )
 
     @property
     def serializer(self) -> Serializer:
+        """Return serializer to serialize/deserialize cache value."""
         return self._serializer
 
 
 def initialize_cache(
     system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
 ):
-    from dbgpt.storage.cache.storage.base import MemoryCacheStorage
+    """Initialize cache manager.
+
+    Args:
+        system_app (SystemApp): The system app.
+        storage_type (str): The storage type.
+        max_memory_mb (int): The max memory in MB.
+        persist_dir (str): The persist directory.
+    """
     from dbgpt.util.serialization.json_serialization import JsonSerializer
 
-    cache_storage = None
+    from .storage.base import MemoryCacheStorage
+
     if storage_type == "disk":
         try:
-            from dbgpt.storage.cache.storage.disk.disk_storage import DiskCacheStorage
+            from .storage.disk.disk_storage import DiskCacheStorage
 
-            cache_storage = DiskCacheStorage(
+            cache_storage: CacheStorage = DiskCacheStorage(
                 persist_dir, mem_table_buffer_mb=max_memory_mb
             )
         except ImportError as e:
             logger.warn(
-                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
+                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error "
+                f"message: {str(e)}"
             )
             cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
     else:
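`LocalCacheManager` awaits storages that report `support_async()` and otherwise pushes the blocking call onto the executor (the role `blocking_func_to_async` plays here), so a synchronous backend never stalls the event loop. A minimal sketch of that fallback, with a stand-in `blocking_get` instead of a real storage read:

import asyncio
from concurrent.futures import ThreadPoolExecutor


def blocking_get(key: str) -> bytes:
    # Stand-in for a synchronous storage read (e.g. an on-disk lookup).
    return f"value-for-{key}".encode()


async def get_async(key: str, executor: ThreadPoolExecutor) -> bytes:
    # Equivalent of the blocking_func_to_async pattern: hand the blocking
    # call to a worker thread and await the result.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(executor, blocking_get, key)


async def main() -> None:
    with ThreadPoolExecutor(max_workers=2) as pool:
        print(await get_async("prompt-hash", pool))


asyncio.run(main())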
@@ -1,5 +1,6 @@
+"""Operators for processing model outputs with caching support."""
 import logging
-from typing import AsyncIterator, Dict, List, Union
+from typing import AsyncIterator, Dict, List, Optional, Union, cast
 
 from dbgpt.core import ModelOutput, ModelRequest
 from dbgpt.core.awel import (

@@ -10,7 +11,9 @@ from dbgpt.core.awel import (
     StreamifyAbsOperator,
     TransformStreamAbsOperator,
 )
-from dbgpt.storage.cache import CacheManager, LLMCacheClient, LLMCacheKey, LLMCacheValue
+
+from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
+from .manager import CacheManager
 
 logger = logging.getLogger(__name__)

@@ -26,15 +29,17 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
         **kwargs: Additional keyword arguments.
 
     Methods:
-        streamify: Processes a stream of inputs with cache support, yielding model outputs.
+        streamify: Processes a stream of inputs with cache support, yielding model
+            outputs.
     """
 
     def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        """Create a new instance of CachedModelStreamOperator."""
         super().__init__(**kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
 
-    async def streamify(self, input_value: ModelRequest) -> AsyncIterator[ModelOutput]:
+    async def streamify(self, input_value: ModelRequest):
         """Process inputs as a stream with cache support and yield model outputs.
 
         Args:

@@ -45,10 +50,13 @@ class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput])
         """
         cache_dict = _parse_cache_key_dict(input_value)
         llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
-        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        llm_cache_value = await self._client.get(llm_cache_key)
         logger.info(f"llm_cache_value: {llm_cache_value}")
-        for out in llm_cache_value.get_value().output:
-            yield out
+        if not llm_cache_value:
+            raise ValueError(f"Cache value not found for key: {llm_cache_key}")
+        outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
+        for out in outputs:
+            yield cast(ModelOutput, out)
 
 
 class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
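Since `get` now returns `Optional[...]`, the operators must eliminate the `None` case before touching the value; the explicit `raise` doubles as a type guard, and `cast` narrows the `Union` output without any runtime conversion. A compact, framework-free sketch of the pattern:

from typing import List, Optional, Union, cast


def first_output(value: Optional[Union[str, List[str]]]) -> str:
    # Guard first: after this check the type checker knows value is not None.
    if not value:
        raise ValueError("Cache value not found")
    # cast() documents (and asserts to the checker) which branch of the
    # Union we expect here; it performs no runtime conversion.
    outputs = cast(List[str], value if isinstance(value, list) else [value])
    return outputs[0]


assert first_output("hello") == "hello"
assert first_output(["a", "b"]) == "a"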
@@ -63,6 +71,7 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
     """
 
     def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
+        """Create a new instance of CachedModelOperator."""
         super().__init__(**kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)

@@ -78,14 +87,18 @@ class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
         """
         cache_dict = _parse_cache_key_dict(input_value)
         llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
-        llm_cache_value: LLMCacheValue = await self._client.get(llm_cache_key)
+        llm_cache_value = await self._client.get(llm_cache_key)
+        if not llm_cache_value:
+            raise ValueError(f"Cache value not found for key: {llm_cache_key}")
         logger.info(f"llm_cache_value: {llm_cache_value}")
-        return llm_cache_value.get_value().output
+        return cast(ModelOutput, llm_cache_value.get_value().output)
 
 
 class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
-    """
-    A branch operator that decides whether to use cached data or to process data using the model.
+    """Branch operator for model processing with cache support.
+
+    A branch operator that decides whether to use cached data or to process data using
+    the model.
 
     Args:
         cache_manager (CacheManager): The cache manager for managing cache operations.

@@ -101,6 +114,7 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
         cache_task_name: str,
         **kwargs,
     ):
+        """Create a new instance of ModelCacheBranchOperator."""
         super().__init__(branches=None, **kwargs)
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)

@@ -110,10 +124,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
     async def branches(
         self,
     ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
-        """Defines branch logic based on cache availability.
+        """Branch logic based on cache availability.
+
+        Defines branch logic based on cache availability.
 
         Returns:
-            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping branch functions to task names.
+            Dict[BranchFunc[Dict], Union[BaseOperator, str]]: A dictionary mapping
+                branch functions to task names.
         """
 
         async def check_cache_true(input_value: ModelRequest) -> bool:

@@ -124,12 +141,13 @@ class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
             cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
             cache_value = await self._client.get(cache_key)
             logger.debug(
-                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: {cache_value}"
+                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
+                f"{cache_value}"
             )
             await self.current_dag_context.save_to_share_data(
                 _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
             )
-            return True if cache_value else False
+            return bool(cache_value)
 
         async def check_cache_false(input_value: ModelRequest):
             # Inverse of check_cache_true

@@ -152,22 +170,25 @@ class ModelStreamSaveCacheOperator(
     """
 
     def __init__(self, cache_manager: CacheManager, **kwargs):
+        """Create a new instance of ModelStreamSaveCacheOperator."""
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
         super().__init__(**kwargs)
 
-    async def transform_stream(
-        self, input_value: AsyncIterator[ModelOutput]
-    ) -> AsyncIterator[ModelOutput]:
-        """Transforms the input stream by saving the outputs to cache.
+    async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
+        """Save the stream of model outputs to cache.
+
+        Transforms the input stream by saving the outputs to cache.
 
         Args:
-            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model outputs.
+            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model
+                outputs.
 
         Returns:
-            AsyncIterator[ModelOutput]: The same input iterator, but the outputs are saved to cache.
+            AsyncIterator[ModelOutput]: The same input iterator, but the outputs are
+                saved to cache.
         """
-        llm_cache_key: LLMCacheKey = None
+        llm_cache_key: Optional[LLMCacheKey] = None
         outputs = []
         async for out in input_value:
             if not llm_cache_key:

@@ -190,12 +211,13 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
     """
 
     def __init__(self, cache_manager: CacheManager, **kwargs):
+        """Create a new instance of ModelSaveCacheOperator."""
         self._cache_manager = cache_manager
         self._client = LLMCacheClient(cache_manager)
         super().__init__(**kwargs)
 
     async def map(self, input_value: ModelOutput) -> ModelOutput:
-        """Saves a single model output to cache and returns it.
+        """Save model output to cache.
 
         Args:
             input_value (ModelOutput): The output from the model to be cached.

@@ -213,7 +235,7 @@ class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
 
 
 def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
-    """Parses and extracts relevant fields from input to form a cache key dictionary.
+    """Parse and extract relevant fields from input to form a cache key dictionary.
 
     Args:
         input_value (Dict): The input dictionary containing model and prompt parameters.
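The branch operator maps mutually exclusive async predicates to downstream task names; whichever predicate returns True decides the route (cache hit vs. model call). A framework-free sketch of that routing idea (the AWEL runtime does the real dispatch):

import asyncio
from typing import Awaitable, Callable, Dict

BranchFunc = Callable[[str], Awaitable[bool]]


async def route(request: str, branches: Dict[BranchFunc, str]) -> str:
    # Evaluate each predicate; return the task name of the first that matches.
    for predicate, task_name in branches.items():
        if await predicate(request):
            return task_name
    raise ValueError("No branch matched")


async def check_cache_true(request: str) -> bool:
    return request.startswith("cached:")  # stand-in for a real cache probe


async def check_cache_false(request: str) -> bool:
    return not await check_cache_true(request)  # inverse of check_cache_true


branches: Dict[BranchFunc, str] = {
    check_cache_true: "cache_task",
    check_cache_false: "model_task",
}
assert asyncio.run(route("cached:hi", branches)) == "cache_task"
assert asyncio.run(route("hi", branches)) == "model_task"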
dbgpt/storage/cache/protocol/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
+"""Module for protocol."""
dbgpt/storage/cache/storage/__init__.py (1 line changed)

@@ -0,0 +1 @@
+"""Module for cache storage implementation."""
dbgpt/storage/cache/storage/base.py (26 lines changed)

@@ -1,3 +1,4 @@
+"""Base cache storage class."""
 import logging
 from abc import ABC, abstractmethod
 from collections import OrderedDict

@@ -22,8 +23,7 @@ logger = logging.getLogger(__name__)
 
 @dataclass
 class StorageItem:
-    """
-    A class representing a storage item.
+    """A class representing a storage item.
 
     This class encapsulates data related to a storage item, such as its length,
     the hash of the key, and the data for both the key and value.

@@ -44,6 +44,7 @@ class StorageItem:
     def build_from(
         key_hash: bytes, key_data: bytes, value_data: bytes
     ) -> "StorageItem":
+        """Build a StorageItem from the provided key and value data."""
         length = (
             32
             + _get_object_bytes(key_hash)

@@ -56,6 +57,7 @@ class StorageItem:
 
     @staticmethod
     def build_from_kv(key: CacheKey[K], value: CacheValue[V]) -> "StorageItem":
+        """Build a StorageItem from the provided key and value."""
         key_hash = key.get_hash_bytes()
         key_data = key.serialize()
         value_data = value.serialize()

@@ -105,6 +107,8 @@ class StorageItem:
 
 
 class CacheStorage(ABC):
+    """Base class for cache storage."""
+
     @abstractmethod
     def check_config(
         self,

@@ -122,6 +126,7 @@ class CacheStorage(ABC):
         """
 
     def support_async(self) -> bool:
+        """Check whether the storage support async operation."""
         return False
 
     @abstractmethod

@@ -135,7 +140,8 @@ class CacheStorage(ABC):
             cache_config (Optional[CacheConfig]): Cache config
 
         Returns:
-            Optional[StorageItem]: The storage item retrieved according to key. If cache key not exist, return None.
+            Optional[StorageItem]: The storage item retrieved according to key. If
+                cache key not exist, return None.
         """
 
     async def aget(

@@ -148,7 +154,8 @@ class CacheStorage(ABC):
             cache_config (Optional[CacheConfig]): Cache config
 
         Returns:
-            Optional[StorageItem]: The storage item of bytes retrieved according to key. If cache key not exist, return None.
+            Optional[StorageItem]: The storage item of bytes retrieved according to
+                key. If cache key not exist, return None.
         """
         raise NotImplementedError

@@ -184,8 +191,11 @@ class CacheStorage(ABC):
 
 
 class MemoryCacheStorage(CacheStorage):
+    """A simple in-memory cache storage implementation."""
+
     def __init__(self, max_memory_mb: int = 256):
-        self.cache = OrderedDict()
+        """Create a new instance of MemoryCacheStorage."""
+        self.cache: OrderedDict = OrderedDict()
         self.max_memory = max_memory_mb * 1024 * 1024
         self.current_memory_usage = 0

@@ -194,6 +204,7 @@ class MemoryCacheStorage(CacheStorage):
         cache_config: Optional[CacheConfig] = None,
         raise_error: Optional[bool] = True,
     ) -> bool:
+        """Check whether the CacheConfig is legal."""
         if (
             cache_config
             and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH

@@ -208,10 +219,11 @@ class MemoryCacheStorage(CacheStorage):
     def get(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key."""
         self.check_config(cache_config, raise_error=True)
         # Exact match retrieval
         key_hash = hash(key)
-        item: StorageItem = self.cache.get(key_hash)
+        item: Optional[StorageItem] = self.cache.get(key_hash)
         logger.debug(f"MemoryCacheStorage get key {key}, hash {key_hash}, item: {item}")
 
         if not item:

@@ -226,6 +238,7 @@ class MemoryCacheStorage(CacheStorage):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         key_hash = hash(key)
         item = StorageItem.build_from_kv(key, value)
         # Calculate memory size of the new entry

@@ -242,6 +255,7 @@ class MemoryCacheStorage(CacheStorage):
     def exists(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> bool:
+        """Check if the key exists in the cache."""
         return self.get(key, cache_config) is not None
 
     def _apply_cache_policy(self, cache_config: Optional[CacheConfig] = None):
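`MemoryCacheStorage` keys an `OrderedDict` by the key hash and tracks an approximate byte budget in `current_memory_usage`. A self-contained sketch of that shape; the eviction step shown (drop the oldest insertion until the new entry fits) is an illustrative assumption, not necessarily the repository's exact policy:

from collections import OrderedDict
from typing import Optional


class TinyMemoryCache:
    """OrderedDict-backed cache with an approximate memory budget."""

    def __init__(self, max_memory_mb: int = 1):
        self.cache: OrderedDict = OrderedDict()
        self.max_memory = max_memory_mb * 1024 * 1024
        self.current_memory_usage = 0

    def set(self, key_hash: int, data: bytes) -> None:
        # Evict oldest entries (insertion order) until the new item fits
        # within the budget; assumed policy for illustration.
        while self.cache and self.current_memory_usage + len(data) > self.max_memory:
            _, evicted = self.cache.popitem(last=False)
            self.current_memory_usage -= len(evicted)
        self.cache[key_hash] = data
        self.current_memory_usage += len(data)

    def get(self, key_hash: int) -> Optional[bytes]:
        return self.cache.get(key_hash)


c = TinyMemoryCache()
c.set(hash("k1"), b"payload")
assert c.get(hash("k1")) == b"payload"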
dbgpt/storage/cache/storage/disk/__init__.py (1 line changed)

@@ -0,0 +1 @@
+"""Disk cache storage implementation."""
dbgpt/storage/cache/storage/disk/disk_storage.py (22 lines changed)

@@ -1,3 +1,7 @@
+"""Disk storage for cache.
+
+Implement the cache storage using rocksdb.
+"""
 import logging
 from typing import Optional
 

@@ -11,14 +15,14 @@ from dbgpt.core.interface.cache import (
     RetrievalPolicy,
     V,
 )
-from dbgpt.storage.cache.storage.base import CacheStorage, StorageItem
+
+from ..base import CacheStorage, StorageItem
 
 logger = logging.getLogger(__name__)
 
 
-def db_options(
-    mem_table_buffer_mb: Optional[int] = 256, background_threads: Optional[int] = 2
-):
+def db_options(mem_table_buffer_mb: int = 256, background_threads: int = 2):
+    """Create rocksdb options."""
     opt = Options()
     # create table
     opt.create_if_missing(True)

@@ -42,9 +46,10 @@ def db_options(
 
 
 class DiskCacheStorage(CacheStorage):
-    def __init__(
-        self, persist_dir: str, mem_table_buffer_mb: Optional[int] = 256
-    ) -> None:
+    """Disk cache storage using rocksdb."""
+
+    def __init__(self, persist_dir: str, mem_table_buffer_mb: int = 256) -> None:
+        """Create a new instance of DiskCacheStorage."""
         super().__init__()
         self.db: Rdict = Rdict(
             persist_dir, db_options(mem_table_buffer_mb=mem_table_buffer_mb)

@@ -55,6 +60,7 @@ class DiskCacheStorage(CacheStorage):
         cache_config: Optional[CacheConfig] = None,
         raise_error: Optional[bool] = True,
     ) -> bool:
+        """Check whether the CacheConfig is legal."""
         if (
             cache_config
             and cache_config.retrieval_policy != RetrievalPolicy.EXACT_MATCH

@@ -69,6 +75,7 @@ class DiskCacheStorage(CacheStorage):
     def get(
         self, key: CacheKey[K], cache_config: Optional[CacheConfig] = None
     ) -> Optional[StorageItem]:
+        """Retrieve a storage item from the cache using the provided key."""
         self.check_config(cache_config, raise_error=True)
 
         # Exact match retrieval

@@ -86,6 +93,7 @@ class DiskCacheStorage(CacheStorage):
         value: CacheValue[V],
         cache_config: Optional[CacheConfig] = None,
     ) -> None:
+        """Set a value in the cache for the provided key."""
         item = StorageItem.build_from_kv(key, value)
         key_hash = item.key_hash
         self.db[key_hash] = item.serialize()
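`DiskCacheStorage` persists the serialized `StorageItem` bytes under the key hash through the `rocksdict` binding for RocksDB. A minimal sketch of that usage, assuming `rocksdict` is installed and using a throwaway directory:

import tempfile

from rocksdict import Options, Rdict

opt = Options()
opt.create_if_missing(True)  # create the database directory on first open

persist_dir = tempfile.mkdtemp(prefix="cache-")
db = Rdict(persist_dir, opt)

# Bytes in, bytes out, mirroring db[key_hash] = item.serialize()
# in DiskCacheStorage.set.
db[b"key-hash"] = b"serialized-storage-item"
assert db[b"key-hash"] == b"serialized-storage-item"

db.close()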
@@ -1,5 +1,3 @@
 import pytest
-
 from dbgpt.util.memory_utils import _get_object_bytes
-
 from ..base import StorageItem
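The touched test module imports `_get_object_bytes` and `StorageItem`, which suggests it checks the size bookkeeping in `build_from` (32 bytes of fixed overhead plus the measured sizes of hash, key, and value, per the hunk in base.py; the trailing terms of that sum are assumed here). A hypothetical test in that spirit, not the repository's actual test body:

from dbgpt.storage.cache.storage.base import StorageItem
from dbgpt.util.memory_utils import _get_object_bytes


def test_build_from_length() -> None:
    key_hash, key_data, value_data = b"\x01" * 32, b"key", b"value"
    item = StorageItem.build_from(key_hash, key_data, value_data)
    # Assumed formula: fixed 32-byte overhead plus measured component sizes.
    expected = (
        32
        + _get_object_bytes(key_hash)
        + _get_object_bytes(key_data)
        + _get_object_bytes(value_data)
    )
    assert item.length == expected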