DB-GPT/dbgpt/storage/cache/llm_cache.py
2024-03-15 15:42:46 +08:00

197 lines
6.0 KiB
Python

"""Cache client for LLM."""
import hashlib
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
from .manager import CacheManager
@dataclass
class LLMCacheKeyData:
    """Cache key data for LLM.

    Groups the request parameters that identify a cached LLM call; two
    requests whose fields compare equal are treated as the same cache entry.
    """

    # The prompt text sent to the model.
    prompt: str
    # Name of the model serving the request.
    model_name: str
    # Sampling temperature; part of the key because it changes the output.
    temperature: Optional[float] = 0.7
    # Generation token limit; None presumably means the model default — TODO confirm.
    max_new_tokens: Optional[int] = None
    # Nucleus-sampling parameter.
    top_p: Optional[float] = 1.0
    # See dbgpt.model.base.ModelType
    model_type: Optional[str] = "huggingface"
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
@dataclass
class LLMCacheValueData:
    """Cache value data for LLM.

    Holds the cached model output (one ModelOutput or a list of them) along
    with the user associated with the request.
    """

    output: CacheOutputType
    user: Optional[str] = None
    # Internal flag: True when ``output`` is a list of ModelOutput.
    _is_list: bool = False

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        """Create LLMCacheValueData object from dict.

        ``output`` may be a ModelOutput, a dict (a deserialized ModelOutput),
        or a list mixing both forms; dict entries are converted to ModelOutput.

        Raises:
            ValueError: If ``output`` is missing.
        """
        output = kwargs.get("output")
        # Only reject a *missing* output. The previous truthiness check
        # (``if not output``) also rejected an empty list, which is a valid
        # value that ``to_dict`` can itself produce, breaking round-tripping.
        if output is None:
            raise ValueError("Can't new LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            output = [
                ModelOutput(**out) if isinstance(out, dict) else out
                for out in output
            ]
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        """Convert to dict, turning ModelOutput items into plain dicts."""
        output = self.output
        is_list = isinstance(output, list)
        if is_list:
            output = [out.to_dict() for out in output]  # type: ignore
        else:
            output = output.to_dict()  # type: ignore
        return {"output": output, "_is_list": is_list, "user": self.user}

    @property
    def is_list(self) -> bool:
        """Return whether the output is a list."""
        return self._is_list

    def __str__(self) -> str:
        """Return string representation."""
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output(last two item): {self.output[-2:]}"
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    """Cache key for LLM.

    Wraps an :class:`LLMCacheKeyData` config; hashing is derived from the
    SHA-256 of the serialized key, equality from the config itself.
    """

    def __init__(self, **kwargs) -> None:
        """Create a new instance of LLMCacheKey from key-data fields."""
        super().__init__()
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        """Return the hash value of the object."""
        digest = hashlib.sha256(self.serialize()).digest()
        # Interpreting the digest as a big-endian integer is equivalent to
        # int(hexdigest, 16).
        return int.from_bytes(digest, "big")

    def __eq__(self, other: Any) -> bool:
        """Check equality with another key."""
        return isinstance(other, LLMCacheKey) and self.config == other.config

    def get_hash_bytes(self) -> bytes:
        """Return the byte array of hash value.

        Returns:
            bytes: The byte array of hash value.
        """
        return hashlib.sha256(self.serialize()).digest()

    def to_dict(self) -> Dict:
        """Convert to dict."""
        return asdict(self.config)

    def get_value(self) -> LLMCacheKeyData:
        """Return the real object of current cache key."""
        return self.config
class LLMCacheValue(CacheValue[LLMCacheValueData]):
    """Cache value for LLM."""

    def __init__(self, **kwargs) -> None:
        """Create a new instance of LLMCacheValue.

        All keyword arguments are forwarded to
        :meth:`LLMCacheValueData.from_dict`.
        """
        super().__init__()
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        """Convert the wrapped value data to a dict."""
        return self.value.to_dict()

    def get_value(self) -> LLMCacheValueData:
        """Return the underlying real value."""
        return self.value

    def __str__(self) -> str:
        """Return string representation."""
        return "value: " + str(self.value)
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    """Cache client for LLM.

    Thin async facade over a :class:`CacheManager`, specialised to the LLM
    key/value types.
    """

    def __init__(self, cache_manager: CacheManager) -> None:
        """Create a new instance of LLMCacheClient."""
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self,
        key: LLMCacheKey,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[LLMCacheValue]:
        """Retrieve a value from the cache using the provided key.

        Args:
            key (LLMCacheKey): The key to get cache
            cache_config (Optional[CacheConfig]): Cache config

        Returns:
            Optional[LLMCacheValue]: The value retrieved according to key. If cache key
                not exist, return None.
        """
        value = await self._cache_manager.get(key, LLMCacheValue, cache_config)
        return cast(LLMCacheValue, value)

    async def set(
        self,
        key: LLMCacheKey,  # type: ignore
        value: LLMCacheValue,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Set a value in the cache for the provided key."""
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self,
        key: LLMCacheKey,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> bool:
        """Check if a key exists in the cache."""
        cached = await self.get(key, cache_config)
        return cached is not None

    def new_key(self, **kwargs) -> LLMCacheKey:  # type: ignore
        """Create a cache key with params."""
        key = LLMCacheKey(**kwargs)
        key.set_serializer(self._cache_manager.serializer)
        return key

    def new_value(self, **kwargs) -> LLMCacheValue:  # type: ignore
        """Create a cache value with params."""
        value = LLMCacheValue(**kwargs)
        value.set_serializer(self._cache_manager.serializer)
        return value