mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-07-27 22:07:48 +00:00
197 lines
6.0 KiB
Python
197 lines
6.0 KiB
Python
"""Cache client for LLM."""
|
|
|
|
import hashlib
|
|
from dataclasses import asdict, dataclass
|
|
from typing import Any, Dict, List, Optional, Union, cast
|
|
|
|
from dbgpt.core import ModelOutput
|
|
from dbgpt.core.interface.cache import CacheClient, CacheConfig, CacheKey, CacheValue
|
|
|
|
from .manager import CacheManager
|
|
|
|
|
|
@dataclass
class LLMCacheKeyData:
    """Request parameters that identify one cached LLM generation.

    Used as the payload of ``LLMCacheKey``; key equality compares these
    fields, so any differing field yields a different cache key.
    """

    # Prompt text sent to the model.
    prompt: str
    # Name of the model the request targets.
    model_name: str
    # Sampling temperature (default 0.7).
    temperature: Optional[float] = 0.7
    # Optional cap on generated tokens; None when unset.
    max_new_tokens: Optional[int] = None
    # top_p sampling parameter (default 1.0).
    top_p: Optional[float] = 1.0
    # Model backend type; see dbgpt.model.base.ModelType.
    model_type: Optional[str] = "huggingface"
|
|
|
|
|
|
# Shape of the output stored in an LLM cache value: either a single
# ModelOutput or a list of ModelOutput objects.
CacheOutputType = Union[ModelOutput, List[ModelOutput]]
|
|
|
|
|
|
@dataclass
class LLMCacheValueData:
    """Cache value data for LLM.

    ``output`` holds either a single ``ModelOutput`` or a list of them. The
    private ``_is_list`` flag records which shape is stored and is exposed
    through :attr:`is_list`.
    """

    output: CacheOutputType
    user: Optional[str] = None
    _is_list: bool = False

    def __post_init__(self) -> None:
        """Keep ``_is_list`` consistent with the actual type of ``output``.

        ``from_dict`` sets the flag for list outputs, but an instance built
        directly with a list ``output`` would otherwise report
        ``is_list == False`` while ``to_dict`` serializes it as a list.
        The flag is only ever promoted to True here, never cleared.
        """
        if isinstance(self.output, list):
            self._is_list = True

    @staticmethod
    def from_dict(**kwargs) -> "LLMCacheValueData":
        """Create LLMCacheValueData object from dict.

        Args:
            **kwargs: Field values, e.g. as produced by :meth:`to_dict`.
                ``output`` may be a dict of ``ModelOutput`` fields, a list
                whose dict items are converted to ``ModelOutput``, or any
                other value, which is used as-is.

        Returns:
            LLMCacheValueData: The reconstructed value.

        Raises:
            ValueError: If ``output`` is missing or falsy (``None``, an empty
                dict or an empty list).
        """
        output = kwargs.get("output")
        if not output:
            raise ValueError("Can't new LLMCacheValueData object, output is None")
        if isinstance(output, dict):
            output = ModelOutput(**output)
        elif isinstance(output, list):
            kwargs["_is_list"] = True
            # Convert serialized (dict) items back into ModelOutput; leave
            # items that are already objects untouched.
            output = [
                ModelOutput(**out) if isinstance(out, dict) else out
                for out in output
            ]
        kwargs["output"] = output
        return LLMCacheValueData(**kwargs)

    def to_dict(self) -> Dict:
        """Convert to a plain dict that :meth:`from_dict` can read back.

        Returns:
            Dict: ``{"output": ..., "_is_list": ..., "user": ...}``, where
            ``output`` is the ``to_dict()`` form of each stored output and
            ``_is_list`` reflects the actual type of ``self.output``.
        """
        output = self.output
        if isinstance(output, list):
            return {
                "output": [out.to_dict() for out in output],
                "_is_list": True,
                "user": self.user,
            }
        return {"output": output.to_dict(), "_is_list": False, "user": self.user}

    @property
    def is_list(self) -> bool:
        """Return whether the output is a list."""
        return self._is_list

    def __str__(self) -> str:
        """Return string representation.

        For list outputs only the last two items are included.
        """
        if not isinstance(self.output, list):
            return f"user: {self.user}, output: {self.output}"
        else:
            return f"user: {self.user}, output(last two item): {self.output[-2:]}"
|
|
|
|
|
|
class LLMCacheKey(CacheKey[LLMCacheKeyData]):
    """Cache key for LLM.

    Wraps an ``LLMCacheKeyData`` payload. Hashing uses the SHA-256 digest of
    the serialized key; equality compares the payload dataclasses.
    """

    def __init__(self, **kwargs) -> None:
        """Build the key; all keyword arguments go to ``LLMCacheKeyData``."""
        super().__init__()
        self.config = LLMCacheKeyData(**kwargs)

    def __hash__(self) -> int:
        """Return the SHA-256 digest of the serialized key as an integer."""
        raw_digest = hashlib.sha256(self.serialize()).digest()
        # Big-endian decoding of the raw digest equals int(hexdigest, 16).
        return int.from_bytes(raw_digest, byteorder="big")

    def __eq__(self, other: Any) -> bool:
        """Return True only for another LLMCacheKey with an equal payload."""
        return isinstance(other, LLMCacheKey) and self.config == other.config

    def get_hash_bytes(self) -> bytes:
        """Return the byte array of hash value.

        Returns:
            bytes: The raw 32-byte SHA-256 digest of the serialized key.
        """
        hasher = hashlib.sha256(self.serialize())
        return hasher.digest()

    def to_dict(self) -> Dict:
        """Return the payload fields as a plain dict."""
        payload = self.config
        return asdict(payload)

    def get_value(self) -> LLMCacheKeyData:
        """Return the wrapped ``LLMCacheKeyData`` payload."""
        return self.config
|
|
|
|
|
|
class LLMCacheValue(CacheValue[LLMCacheValueData]):
    """Cache value for LLM, wrapping an ``LLMCacheValueData`` payload."""

    def __init__(self, **kwargs) -> None:
        """Build the value by passing ``kwargs`` to ``LLMCacheValueData.from_dict``."""
        super().__init__()
        self.value = LLMCacheValueData.from_dict(**kwargs)

    def to_dict(self) -> Dict:
        """Return the payload's dict form (delegates to the payload)."""
        payload = self.value
        return payload.to_dict()

    def get_value(self) -> LLMCacheValueData:
        """Return the wrapped ``LLMCacheValueData``."""
        return self.value

    def __str__(self) -> str:
        """Return the payload's string form prefixed with ``value:``."""
        return "value: " + str(self.value)
|
|
|
|
|
|
class LLMCacheClient(CacheClient[LLMCacheKeyData, LLMCacheValueData]):
    """Cache client for LLM; storage operations delegate to a CacheManager."""

    def __init__(self, cache_manager: CacheManager) -> None:
        """Create a client bound to ``cache_manager``."""
        super().__init__()
        self._cache_manager: CacheManager = cache_manager

    async def get(
        self,
        key: LLMCacheKey,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> Optional[LLMCacheValue]:
        """Retrieve a value from the cache using the provided key.

        Args:
            key (LLMCacheKey): The key to look up.
            cache_config (Optional[CacheConfig]): Cache config forwarded to the
                manager.

        Returns:
            Optional[LLMCacheValue]: The cached value, or None if the key is
            not present.
        """
        cached = await self._cache_manager.get(key, LLMCacheValue, cache_config)
        return cast(LLMCacheValue, cached)

    async def set(
        self,
        key: LLMCacheKey,  # type: ignore
        value: LLMCacheValue,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        """Store ``value`` under ``key`` via the cache manager."""
        return await self._cache_manager.set(key, value, cache_config)

    async def exists(
        self,
        key: LLMCacheKey,  # type: ignore
        cache_config: Optional[CacheConfig] = None,
    ) -> bool:
        """Report whether ``key`` is cached by fetching it and testing for None."""
        found = await self.get(key, cache_config)
        return found is not None

    def new_key(self, **kwargs) -> LLMCacheKey:  # type: ignore
        """Build an LLMCacheKey from ``kwargs`` using the manager's serializer."""
        fresh_key = LLMCacheKey(**kwargs)
        fresh_key.set_serializer(self._cache_manager.serializer)
        return fresh_key

    def new_value(self, **kwargs) -> LLMCacheValue:  # type: ignore
        """Build an LLMCacheValue from ``kwargs`` using the manager's serializer."""
        fresh_value = LLMCacheValue(**kwargs)
        fresh_value.set_serializer(self._cache_manager.serializer)
        return fresh_value
|