feat(cache): Support configuring the model cache in .env

This commit is contained in:
FangYin Cheng
2023-11-16 18:46:00 +08:00
parent 1150adbe6a
commit 995772077c
5 changed files with 57 additions and 17 deletions

View File

@@ -55,6 +55,17 @@ QUANTIZE_8bit=True
## Model path
# llama_cpp_model_path=/data/models/TheBloke/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q4_K_M.gguf
### LLM cache
## Enable Model cache
# MODEL_CACHE_ENABLE=True
## The storage type of model cache, now supports: memory, disk
# MODEL_CACHE_STORAGE_TYPE=disk
## The max cache data in memory, we always store cache data in memory first for high speed.
# MODEL_CACHE_MAX_MEMORY_MB=256
## The dir to save cache data, this configuration is only valid when MODEL_CACHE_STORAGE_TYPE=disk
## The default dir is pilot/data/model_cache
# MODEL_CACHE_STORAGE_DISK_DIR=
#*******************************************************************#
#** EMBEDDING SETTINGS **#
#*******************************************************************#

View File

@@ -100,19 +100,27 @@ class LocalCacheManager(CacheManager):
return self._serializer
def initialize_cache(
    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
):
    """Register a LocalCacheManager component on the system app.

    Args:
        system_app: Application to register the cache manager component on.
        storage_type: Cache backend selector; "disk" uses DiskCacheStorage,
            any other value falls back to in-memory storage.
        max_memory_mb: Memory budget in MB (in-memory cache size, or the
            mem-table buffer size for the disk backend).
        persist_dir: Directory for on-disk cache data; only used when
            storage_type == "disk".
    """
    from pilot.cache.protocal.json_protocal import JsonSerializer
    from pilot.cache.storage.base import MemoryCacheStorage

    cache_storage = None
    if storage_type == "disk":
        try:
            # Disk storage has optional dependencies; fall back to memory
            # storage if they are not installed.
            from pilot.cache.storage.disk.disk_storage import DiskCacheStorage

            cache_storage = DiskCacheStorage(
                persist_dir, mem_table_buffer_mb=max_memory_mb
            )
        except ImportError as e:
            # logger.warn is deprecated in the logging module; use warning.
            logger.warning(
                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
            )
            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    else:
        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
    system_app.register(
        LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
    )

View File

@@ -184,7 +184,7 @@ class CacheStorage(ABC):
class MemoryCacheStorage(CacheStorage):
def __init__(self, max_memory_mb: int = 256):
    """In-memory cache storage backed by an ordered dict.

    Args:
        max_memory_mb: Memory budget in megabytes; stored internally in
            bytes for usage accounting. Defaults to 256 MB.
    """
    # OrderedDict preserves insertion order so eviction policies can rely
    # on item ordering.
    self.cache = OrderedDict()
    self.max_memory = max_memory_mb * 1024 * 1024
    self.current_memory_usage = 0

View File

@@ -253,9 +253,17 @@ class Config(metaclass=Singleton):
### Temporary configuration
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DIST_DIR"
self.MODEL_CACHE_ENABLE: bool = (
os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
)
self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
"MODEL_CACHE_STORAGE_TYPE", "disk"
)
self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
)
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
"MODEL_CACHE_STORAGE_DISK_DIR"
)
def set_debug_mode(self, value: bool) -> None:

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
import os
from pilot.component import ComponentType, SystemApp
from pilot.configs.config import Config
from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
from pilot.utils.executor_utils import DefaultExecutorFactory
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
CFG = Config()
def initialize_components(
param: WebWerverParameters,
@@ -41,10 +44,7 @@ def initialize_components(
_initialize_embedding_model(
param, system_app, embedding_model_name, embedding_model_path
)
from pilot.cache import initialize_cache
initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
_initialize_model_cache(system_app)
def _initialize_embedding_model(
@@ -136,3 +136,16 @@ class LocalEmbeddingFactory(EmbeddingFactory):
loader = EmbeddingLoader()
# Ignore model_name args
return loader.load(self._default_model_name, model_params)
def _initialize_model_cache(system_app: SystemApp):
    """Initialize the model cache component from the global Config.

    Reads the MODEL_CACHE_* settings and registers the cache on the
    system app; does nothing when the cache is disabled.
    """
    from pilot.cache import initialize_cache

    if not CFG.MODEL_CACHE_ENABLE:
        # Fixed grammar in the log message ("is not enable" -> "is not enabled").
        logger.info("Model cache is not enabled")
        return

    # Fall back to sane defaults when the env vars are unset/empty.
    storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk"
    max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
    persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
    initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)