chore: Add pylint for storage (#1298)

Fangyin Cheng
2024-03-15 15:42:46 +08:00
committed by GitHub
parent a207640ff2
commit 8897d6e8fd
50 changed files with 784 additions and 667 deletions

dbgpt/storage/cache/operators.py (new file, 269 lines)

@@ -0,0 +1,269 @@
"""Operators for processing model outputs with caching support."""
import logging
from typing import AsyncIterator, Dict, List, Optional, Union, cast
from dbgpt.core import ModelOutput, ModelRequest
from dbgpt.core.awel import (
BaseOperator,
BranchFunc,
BranchOperator,
MapOperator,
StreamifyAbsOperator,
TransformStreamAbsOperator,
)
from .llm_cache import LLMCacheClient, LLMCacheKey, LLMCacheValue
from .manager import CacheManager
logger = logging.getLogger(__name__)
_LLM_MODEL_INPUT_VALUE_KEY = "llm_model_input_value"
_LLM_MODEL_OUTPUT_CACHE_KEY = "llm_model_output_cache"
class CachedModelStreamOperator(StreamifyAbsOperator[ModelRequest, ModelOutput]):
"""Operator for streaming processing of model outputs with caching.
Args:
cache_manager (CacheManager): The cache manager to handle caching operations.
**kwargs: Additional keyword arguments.
Methods:
streamify: Processes a stream of inputs with cache support, yielding model
outputs.
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
"""Create a new instance of CachedModelStreamOperator."""
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
async def streamify(self, input_value: ModelRequest):
"""Process inputs as a stream with cache support and yield model outputs.
Args:
input_value (ModelRequest): The input value for the model.
        Yields:
            ModelOutput: The model outputs retrieved from the cache.
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value = await self._client.get(llm_cache_key)
logger.info(f"llm_cache_value: {llm_cache_value}")
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
outputs = cast(List[ModelOutput], llm_cache_value.get_value().output)
for out in outputs:
yield cast(ModelOutput, out)
class CachedModelOperator(MapOperator[ModelRequest, ModelOutput]):
"""Operator for map-based processing of model outputs with caching.
Args:
cache_manager (CacheManager): Manager for caching operations.
**kwargs: Additional keyword arguments.
Methods:
map: Processes a single input with cache support and returns the model output.
"""
def __init__(self, cache_manager: CacheManager, **kwargs) -> None:
"""Create a new instance of CachedModelOperator."""
super().__init__(**kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
async def map(self, input_value: ModelRequest) -> ModelOutput:
"""Process a single input with cache support and return the model output.
Args:
input_value (ModelRequest): The input value for the model.
Returns:
ModelOutput: The output from the model.
"""
cache_dict = _parse_cache_key_dict(input_value)
llm_cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
llm_cache_value = await self._client.get(llm_cache_key)
if not llm_cache_value:
raise ValueError(f"Cache value not found for key: {llm_cache_key}")
logger.info(f"llm_cache_value: {llm_cache_value}")
return cast(ModelOutput, llm_cache_value.get_value().output)
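
# Illustrative sketch (not part of this module): the LLMCacheClient round trip that
# both cached operators rely on. The prompt and model parameters below are example
# values; ``cache_manager`` and ``model_output`` are assumed to be provided by the
# caller.
async def _example_cache_round_trip(
    cache_manager: CacheManager, model_output: ModelOutput
):
    client = LLMCacheClient(cache_manager)
    # Build the cache key from the same fields used by _parse_cache_key_dict below.
    key = client.new_key(
        prompt="What is AWEL?",
        model_name="vicuna-13b-v1.5",
        temperature=0.7,
        max_new_tokens=256,
    )
    cached = await client.get(key)
    if not cached:
        # Cache miss: run the model elsewhere, then persist its output.
        await client.set(key, client.new_value(output=model_output))
        return model_output
    # Cache hit: reuse the stored output.
    return cached.get_value().output
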
class ModelCacheBranchOperator(BranchOperator[ModelRequest, Dict]):
"""Branch operator for model processing with cache support.
A branch operator that decides whether to use cached data or to process data using
the model.
Args:
cache_manager (CacheManager): The cache manager for managing cache operations.
model_task_name (str): The name of the task to process data using the model.
cache_task_name (str): The name of the task to process data using the cache.
**kwargs: Additional keyword arguments.
"""
def __init__(
self,
cache_manager: CacheManager,
model_task_name: str,
cache_task_name: str,
**kwargs,
):
"""Create a new instance of ModelCacheBranchOperator."""
super().__init__(branches=None, **kwargs)
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
self._model_task_name = model_task_name
self._cache_task_name = cache_task_name
async def branches(
self,
) -> Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]:
"""Branch logic based on cache availability.
Defines branch logic based on cache availability.
Returns:
            Dict[BranchFunc[ModelRequest], Union[BaseOperator, str]]: A dictionary
                mapping branch functions to task names.
"""
async def check_cache_true(input_value: ModelRequest) -> bool:
# Check if the cache contains the result for the given input
if input_value.context and not input_value.context.cache_enable:
return False
cache_dict = _parse_cache_key_dict(input_value)
cache_key: LLMCacheKey = self._client.new_key(**cache_dict)
cache_value = await self._client.get(cache_key)
logger.debug(
f"cache_key: {cache_key}, hash key: {hash(cache_key)}, cache_value: "
f"{cache_value}"
)
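            # Save the cache key to shared DAG data even on a miss, so the downstream
            # save operators (ModelSaveCacheOperator / ModelStreamSaveCacheOperator)
            # can reuse it when persisting the model output.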
await self.current_dag_context.save_to_share_data(
_LLM_MODEL_INPUT_VALUE_KEY, cache_key, overwrite=True
)
return bool(cache_value)
        async def check_cache_false(input_value: ModelRequest) -> bool:
# Inverse of check_cache_true
return not await check_cache_true(input_value)
return {
check_cache_true: self._cache_task_name,
check_cache_false: self._model_task_name,
}
class ModelStreamSaveCacheOperator(
TransformStreamAbsOperator[ModelOutput, ModelOutput]
):
"""An operator to save the stream of model outputs to cache.
Args:
cache_manager (CacheManager): The cache manager for handling cache operations.
**kwargs: Additional keyword arguments.
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelStreamSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def transform_stream(self, input_value: AsyncIterator[ModelOutput]):
"""Save the stream of model outputs to cache.
Transforms the input stream by saving the outputs to cache.
Args:
input_value (AsyncIterator[ModelOutput]): An asynchronous iterator of model
outputs.
        Yields:
            ModelOutput: The same outputs as the input stream; once the stream is
                exhausted, successful outputs are saved to cache.
"""
llm_cache_key: Optional[LLMCacheKey] = None
outputs = []
async for out in input_value:
if not llm_cache_key:
llm_cache_key = await self.current_dag_context.get_from_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
outputs.append(out)
yield out
if llm_cache_key and _is_success_model_output(outputs):
llm_cache_value: LLMCacheValue = self._client.new_value(output=outputs)
await self._client.set(llm_cache_key, llm_cache_value)
class ModelSaveCacheOperator(MapOperator[ModelOutput, ModelOutput]):
"""An operator to save a single model output to cache.
Args:
cache_manager (CacheManager): The cache manager for handling cache operations.
**kwargs: Additional keyword arguments.
"""
def __init__(self, cache_manager: CacheManager, **kwargs):
"""Create a new instance of ModelSaveCacheOperator."""
self._cache_manager = cache_manager
self._client = LLMCacheClient(cache_manager)
super().__init__(**kwargs)
async def map(self, input_value: ModelOutput) -> ModelOutput:
"""Save model output to cache.
Args:
input_value (ModelOutput): The output from the model to be cached.
Returns:
ModelOutput: The same input model output.
"""
llm_cache_key: LLMCacheKey = await self.current_dag_context.get_from_share_data(
_LLM_MODEL_INPUT_VALUE_KEY
)
llm_cache_value: LLMCacheValue = self._client.new_value(output=input_value)
if llm_cache_key and _is_success_model_output(input_value):
await self._client.set(llm_cache_key, llm_cache_value)
return input_value
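
# Illustrative sketch (not part of this module): how these operators are typically
# composed in an AWEL DAG. Call this inside a ``with DAG(...)`` block; ``llm_task``
# stands for the operator that actually invokes the model and is assumed to be
# created with task_name="llm_model_node". JoinOperator and its combine_function
# argument are assumed from common AWEL usage.
def _example_wire_cache_branch(
    cache_manager: CacheManager, llm_task: BaseOperator
) -> BaseOperator:
    from dbgpt.core.awel import JoinOperator

    cache_branch_task = ModelCacheBranchOperator(
        cache_manager,
        model_task_name="llm_model_node",
        cache_task_name="llm_cache_node",
    )
    cached_task = CachedModelOperator(cache_manager, task_name="llm_cache_node")
    save_cache_task = ModelSaveCacheOperator(cache_manager)
    # Only one branch runs per request; the join forwards whichever result exists.
    join_task = JoinOperator(
        combine_function=lambda model_out, cached_out: cached_out or model_out
    )
    cache_branch_task >> llm_task >> save_cache_task >> join_task
    cache_branch_task >> cached_task >> join_task
    return join_task
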
def _parse_cache_key_dict(input_value: ModelRequest) -> Dict:
"""Parse and extract relevant fields from input to form a cache key dictionary.
Args:
        input_value (ModelRequest): The model request containing the prompt and
            model parameters.
Returns:
Dict: A dictionary used for generating cache keys.
"""
prompt: str = input_value.messages_to_string().strip()
return {
"prompt": prompt,
"model_name": input_value.model,
"temperature": input_value.temperature,
"max_new_tokens": input_value.max_new_tokens,
# "top_p": input_value.get("top_p", "1.0"),
# TODO pass model_type
# "model_type": input_value.get("model_type", "huggingface"),
}
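
# Illustrative sketch: what _parse_cache_key_dict extracts. The stub request below is
# hypothetical and only provides the attributes the helper reads; real callers pass a
# dbgpt.core.ModelRequest.
def _example_parse_cache_key_dict() -> Dict:
    class _StubRequest:
        model = "vicuna-13b-v1.5"
        temperature = 0.8
        max_new_tokens = 256

        @staticmethod
        def messages_to_string() -> str:
            return "  What is AWEL?  "

    # Result: {"prompt": "What is AWEL?", "model_name": "vicuna-13b-v1.5",
    #          "temperature": 0.8, "max_new_tokens": 256}
    return _parse_cache_key_dict(cast(ModelRequest, _StubRequest()))
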
def _is_success_model_output(out: Union[Dict, ModelOutput, List[ModelOutput]]) -> bool:
    """Check whether a model output (or the last chunk of a stream) succeeded."""
if not out:
return False
if isinstance(out, list):
# check last model output
out = out[-1]
error_code = 0
if isinstance(out, ModelOutput):
error_code = out.error_code
else:
error_code = int(out.get("error_code", 0))
return error_code == 0
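
# Illustrative sketch: how success is decided for dict-style outputs and for streams
# (only the last chunk matters).
def _example_is_success_model_output() -> None:
    assert _is_success_model_output({"error_code": 0})
    assert not _is_success_model_output({"error_code": 1})
    assert not _is_success_model_output([])  # empty stream is treated as a failure
    # For a stream (list of outputs), only the final chunk decides success.
    assert not _is_success_model_output([{"error_code": 0}, {"error_code": 1}])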