# Source file: DB-GPT/dbgpt/storage/cache/llm_cache.py
# Snapshot: 2024-01-10 10:39:04 +08:00 (144 lines, 4.4 KiB, Python)

import hashlib
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union
from dbgpt.core import ModelOutput, Serializer
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
from dbgpt.model.base import ModelType
from dbgpt.storage.cache.manager import CacheManager
@dataclass
class LLMCacheKeyData:
    """Fields that make up the data of an LLM cache key.

    ``prompt`` and ``model_name`` are required; the remaining fields are
    optional and fall back to the defaults declared below.
    """

    # Required fields (no default).
    prompt: str
    model_name: str

    # Optional fields with their defaults.
    temperature: Optional[float] = 0.7
    max_new_tokens: Optional[int] = None
    top_p: Optional[float] = 1.0
    model_type: Optional[str] = ModelType.HF
# A cached output is either one ModelOutput or a list of ModelOutput objects.
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@dataclass
class LLMCacheValueData:
    """Cached model output plus the optional user it belongs to.

    ``output`` holds either a single ``ModelOutput`` or a list of them.
    ``_is_list`` is a regular dataclass field so that dictionaries produced by
    ``to_dict`` (which contain an ``_is_list`` key) can be passed straight back
    to ``from_dict``.
    """

    output: CacheOutputType
    user: Optional[str] = None
    _is_list: Optional[bool] = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        """Build an instance from keyword arguments, rebuilding model outputs.

        Args:
            **kwargs: Field values. ``output`` may be a ``ModelOutput``, a
                dict (expanded into ``ModelOutput(**output)``), or a list
                whose dict items are expanded the same way.

        Returns:
            A new ``LLMCacheValueData``.

        Raises:
            ValueError: If ``output`` is missing or falsy (``None``, an empty
                list, an empty dict, ...).
        """
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't new LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            # Only dict items are rebuilt; anything else is kept as given.
            output = [
                ModelOutput(**item) if isinstance(item, dict) else item
                for item in output
            ]
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        """Return a dict with the keys ``output``, ``_is_list`` and ``user``.

        Each output is converted through its own ``to_dict()``. ``_is_list``
        reflects the actual type of ``output``, not the stored flag.
        """
        output = self.output
        is_list = isinstance(output, list)
        if is_list:
            serialized = [item.to_dict() for item in output]
        else:
            serialized = output.to_dict()
        return {"output": serialized, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        """Whether ``output`` is a list of model outputs.

        The stored ``_is_list`` flag is only set by ``from_dict``; an instance
        built directly through the constructor with a list keeps the default
        ``False``. Checking the payload as well keeps this property in
        agreement with ``to_dict`` in that case, while an explicitly passed
        ``True`` flag is still honoured.
        """
        return bool(self._is_list) or isinstance(self.output, list)

    def __str__(self) -> str:
        """Return a short description; list outputs show only the last two."""
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output(last two item): {self.output[-2:]}"
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    """Cache key that wraps an ``LLMCacheKeyData`` instance.

    Hashing relies on ``self.serialize()``, which is not defined in this
    class (presumably inherited from ``CacheKey`` -- confirm against the base).
    """

    def __init__(self, **kwargs) -> None:
        """Create the key.

        Args:
            **kwargs: Forwarded unchanged to ``LLMCacheKeyData``.
        """
        super().__init__()
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        """Return the SHA-256 digest of the serialized key as an integer."""
        digest = hashlib.sha256(self.serialize()).digest()
        # Big-endian interpretation equals parsing the hex digest in base 16.
        return int.from_bytes(digest, "big")

    def __eq__(self, other: Any) -> bool:
        """Keys are equal when both are ``LLMCacheKey`` with equal config."""
        return isinstance(other, LLMCacheKey) and self.config == other.config

    def get_hash_bytes(self) -> bytes:
        """Return the raw SHA-256 digest of the serialized key."""
        payload = self.serialize()
        hasher = hashlib.sha256(payload)
        return hasher.digest()

    def to_dict(self) -> Dict:
        """Return the key data as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self.config)

    def get_value(self) -> LLMCacheKeyData:
        """Return the wrapped ``LLMCacheKeyData``."""
        return self.config
class LLMCacheValue(CacheValue[LLMCacheValueData]):
    """Cache value that wraps an ``LLMCacheValueData`` instance."""

    def __init__(self, **kwargs) -> None:
        """Create the value.

        Args:
            **kwargs: Forwarded unchanged to ``LLMCacheValueData.from_dict``.
        """
        super().__init__()
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        """Return the wrapped data converted with its own ``to_dict``."""
        data = self.value
        return data.to_dict()

    def get_value(self) -> LLMCacheValueData:
        """Return the wrapped ``LLMCacheValueData``."""
        return self.value

    def __str__(self) -> str:
        """Return ``"value: "`` followed by the string form of the data."""
        return "value: " + str(self.value)
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    """Cache client for LLM results that delegates to a ``CacheManager``."""

    def __init__(self, cache_manager: CacheManager) -> None:
        """Store the cache manager that every operation is forwarded to."""
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> Optional[LLMCacheValue]:
        """Fetch the value for ``key`` from the cache manager.

        ``LLMCacheValue`` is passed to the manager as the value class.
        """
        manager = self._cache_manager
        return await manager.get(key, LLMCacheValue, cache_config)

    async def set(
        self,
        key: LLMCacheKey,
        value: LLMCacheValue,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Store ``value`` under ``key`` through the cache manager."""
        manager = self._cache_manager
        return await manager.set(key, value, cache_config)

    async def exists(
        self, key: LLMCacheKey, cache_config: Optional[CacheConfig] = None
    ) -> bool:
        """Return True when ``get`` yields something other than ``None``."""
        cached = await self.get(key, cache_config)
        return cached is not None

    def new_key(self, **kwargs) -> LLMCacheKey:
        """Create an ``LLMCacheKey`` and give it the manager's serializer."""
        cache_key = LLMCacheKey(**kwargs)
        cache_key.set_serializer(self._cache_manager.serializer)
        return cache_key

    def new_value(self, **kwargs) -> LLMCacheValue:
        """Create an ``LLMCacheValue`` and give it the manager's serializer."""
        cache_value = LLMCacheValue(**kwargs)
        cache_value.set_serializer(self._cache_manager.serializer)
        return cache_value