DB-GPT/dbgpt/storage/cache/operators.py
"""Operators for processing model outputs with caching support."""
import logging
from typing import AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, ModelRequest
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchOperator,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from .manager import CacheManager
logger = logging.getLogger(__name__)
_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
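
# Illustrative wiring of these operators in an AWEL DAG (a sketch only: the
# model operator and the task names below are assumptions, not part of this
# module, and downstream operators are expected to be created with matching
# task_name values). The branch routes to the cache task on a hit and to the
# model task on a miss; the save operators persist fresh outputs for later hits.
#
#   cache_branch = ModelCacheBranchOperator(
#       cache_manager,
#       model_task_name="llm_model_node",
#       cache_task_name="llm_cache_node",
#   )
#   cache_branch >> model_operator >> ModelSaveCacheOperator(cache_manager)
#   cache_branch >> CachedModelOperator(cache_manager)
#
# (For streaming requests, CachedModelStreamOperator and
# ModelStreamSaveCacheOperator play the same roles.)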


class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]):
    """Operator for streaming processing of model outputs with caching.

    Args:
        cache_manager (CacheManager): The cache manager to handle caching operations.
        **kwargs: Additional keyword arguments.

    Methods:
        streamify: Processes a stream of inputs with cache support, yielding model
            outputs.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
        """Create a new instance of CachedModelStreamOperator."""
        super().__init__(**kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)

    async def streamify(self, input_value: ModelRequest):
        """Process inputs as a stream with cache support and yield model outputs.

        Args:
            input_value (ModelRequest): The input value for the model.

        Returns:
            AsyncIterator[ModelOutput]: An asynchronous iterator of model outputs.
        """
        cache_dict = _parse_cache_key_dict(input_value)
        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
        llm_cache_value = await self._client.get(llm_cache_key)
        logger.info(f"llm_cache_value: {llm_cache_value}")
        if not llm_cache_value:
            raise ValueError(f"Cache value not found for key: {llm_cache_key}")
        outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
        for out in outputs:
            yield cast(ModelOutput, out)
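
# Note: the List[ModelOutput] replayed above is the same list of chunks that
# ModelStreamSaveCacheOperator (defined below) collects and stores on a cache
# miss, so a cache hit reproduces the original stream chunk by chunk.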


class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
    """Operator for map-based processing of model outputs with caching.

    Args:
        cache_manager (CacheManager): Manager for caching operations.
        **kwargs: Additional keyword arguments.

    Methods:
        map: Processes a single input with cache support and returns the model output.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
        """Create a new instance of CachedModelOperator."""
        super().__init__(**kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)

    async def map(self, input_value: ModelRequest) -> ModelOutput:
        """Process a single input with cache support and return the model output.

        Args:
            input_value (ModelRequest): The input value for the model.

        Returns:
            ModelOutput: The output from the model.
        """
        cache_dict = _parse_cache_key_dict(input_value)
        llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
        llm_cache_value = await self._client.get(llm_cache_key)
        if not llm_cache_value:
            raise ValueError(f"Cache value not found for key: {llm_cache_key}")
        logger.info(f"llm_cache_value: {llm_cache_value}")
        return cast(ModelOutput, llm_cache_value.get_value().output)


class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
    """Branch operator for model processing with cache support.

    A branch operator that decides whether to use cached data or to process data
    using the model.

    Args:
        cache_manager (CacheManager): The cache manager for managing cache operations.
        model_task_name (str): The name of the task to process data using the model.
        cache_task_name (str): The name of the task to process data using the cache.
        **kwargs: Additional keyword arguments.
    """

    def __init__(
        self,
        cache_manager: CacheManager,
        model_task_name: str,
        cache_task_name: str,
        **kwargs,
    ):
        """Create a new instance of ModelCacheBranchOperator."""
        super().__init__(branches=None, **kwargs)
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        self._model_task_name = model_task_name
        self._cache_task_name = cache_task_name

    async def branches(
        self,
    ) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
        """Define branch logic based on cache availability.

        Returns:
            Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: A dictionary
                mapping branch check functions to downstream task names.
        """

        async def check_cache_true(input_value: ModelRequest) -> bool:
            # Check if the cache contains the result for the given input
            if input_value.context and not input_value.context.cache_enable:
                return False
            cache_dict = _parse_cache_key_dict(input_value)
            cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
            cache_value = await self._client.get(cache_key)
            logger.debug(
                f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
                f"{cache_value}"
            )
            await self.current_dag_context.save_to_share_data(
                _LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
            )
            return bool(cache_value)

        async def check_cache_false(input_value: ModelRequest):
            # Inverse of check_cache_true
            return not await check_cache_true(input_value)

        return {
            check_cache_true: self._cache_task_name,
            check_cache_false: self._model_task_name,
        }
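
# The cache key computed in check_cache_true is stashed in the DAG's shared
# data under _LLM_MODEL_INPUT_VALUE_KEY; the save-cache operators below read it
# back, so the model path does not have to rebuild the key before persisting
# its output.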


class ModelStreamSaveCacheOperator(
    TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
    """An operator to save the stream of model outputs to cache.

    Args:
        cache_manager (CacheManager): The cache manager for handling cache operations.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs):
        """Create a new instance of ModelStreamSaveCacheOperator."""
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
        """Save the stream of model outputs to cache.

        Transforms the input stream by saving the outputs to cache.

        Args:
            input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of
                model outputs.

        Returns:
            AsyncIterator[ModelOutput]: The same input iterator, but the outputs
                are saved to cache.
        """
        llm_cache_key: Optional[LLMCacheKey] = None
        outputs = []
        async for out in input_value:
            if not llm_cache_key:
                llm_cache_key = await self.current_dag_context.get_from_share_data(
                    _LLM_MODEL_INPUT_VALUE_KEY
                )
            outputs.append(out)
            yield out
        if llm_cache_key and _is_success_model_output(outputs):
            llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
            await self._client.set(llm_cache_key, llm_cache_value)
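
# Note: nothing is cached here if the branch operator never stashed a cache key
# (for example, when caching is disabled for the request) or if the final chunk
# reports a non-zero error_code, so failed responses are not persisted.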


class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
    """An operator to save a single model output to cache.

    Args:
        cache_manager (CacheManager): The cache manager for handling cache operations.
        **kwargs: Additional keyword arguments.
    """

    def __init__(self, cache_manager: CacheManager, **kwargs):
        """Create a new instance of ModelSaveCacheOperator."""
        self._cache_manager = cache_manager
        self._client = LLMCacheClient(cache_manager)
        super().__init__(**kwargs)

    async def map(self, input_value: ModelOutput) -> ModelOutput:
        """Save model output to cache.

        Args:
            input_value (ModelOutput): The output from the model to be cached.

        Returns:
            ModelOutput: The same input model output.
        """
        llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
            _LLM_MODEL_INPUT_VALUE_KEY
        )
        llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
        if llm_cache_key and _is_success_model_output(input_value):
            await self._client.set(llm_cache_key, llm_cache_value)
        return input_value


def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
    """Parse and extract relevant fields from input to form a cache key dictionary.

    Args:
        input_value (ModelRequest): The model request containing model and prompt
            parameters.

    Returns:
        Dict: A dictionary used for generating cache keys.
    """
    prompt: str = input_value.messages_to_string().strip()
    return {
        "prompt": prompt,
        "model_name": input_value.model,
        "temperature": input_value.temperature,
        "max_new_tokens": input_value.max_new_tokens,
        # "top_p": input_value.get("top_p", "1.0"),
        # TODO pass model_type
        # "model_type": input_value.get("model_type", "huggingface"),
    }
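
# Illustrative shape of the resulting cache key dict (values are examples only,
# not real request data):
#
#   {
#       "prompt": "<the request messages flattened to a single string>",
#       "model_name": "vicuna-13b-v1.5",
#       "temperature": 0.7,
#       "max_new_tokens": 1024,
#   }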


def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
    """Return True if the output (or the last chunk of a streamed output) succeeded."""
    if not out:
        return False
    if isinstance(out, list):
        # check last model output
        out = out[-1]
    error_code = 0
    if isinstance(out, ModelOutput):
        error_code = out.error_code
    else:
        error_code = int(out.get("error_code", 0))
    return error_code == 0
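
# Illustrative behavior (a reading of the checks above, not an exhaustive spec):
# a single ModelOutput or dict is judged by its error_code, while a list of
# streamed chunks is judged only by its final chunk, which is assumed to carry
# the terminal status of the stream.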