diff --git a/.env.template b/.env.template
index e03650033..ba7f752db 100644
--- a/.env.template
+++ b/.env.template
@@ -55,6 +55,17 @@ QUANTIZE_8bit=True
 ## Model path
 # llama_cpp_model_path=/data/models/TheBloke/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q4_K_M.gguf
 
+### LLM cache
+## Enable Model cache
+# MODEL_CACHE_ENABLE=True
+## The storage type of model cache, now supports: memory, disk
+# MODEL_CACHE_STORAGE_TYPE=disk
+## The max cache data in memory, we always store cache data in memory first for high speed.
+# MODEL_CACHE_MAX_MEMORY_MB=256
+## The dir to save cache data, this configuration is only valid when MODEL_CACHE_STORAGE_TYPE=disk
+## The default dir is pilot/data/model_cache
+# MODEL_CACHE_STORAGE_DISK_DIR=
+
 #*******************************************************************#
 #**                     EMBEDDING SETTINGS                        **#
 #*******************************************************************#
diff --git a/pilot/cache/manager.py b/pilot/cache/manager.py
index 093b52bdd..0e76df0b3 100644
--- a/pilot/cache/manager.py
+++ b/pilot/cache/manager.py
@@ -100,19 +100,27 @@ class LocalCacheManager(CacheManager):
         return self._serializer
 
 
-def initialize_cache(system_app: SystemApp, persist_dir: str):
+def initialize_cache(
+    system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
+):
     from pilot.cache.protocal.json_protocal import JsonSerializer
     from pilot.cache.storage.base import MemoryCacheStorage
 
-    try:
-        from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
+    cache_storage = None
+    if storage_type == "disk":
+        try:
+            from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
 
-        cache_storage = DiskCacheStorage(persist_dir)
-    except ImportError as e:
-        logger.warn(
-            f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
-        )
-        cache_storage = MemoryCacheStorage()
+            cache_storage = DiskCacheStorage(
+                persist_dir, mem_table_buffer_mb=max_memory_mb
+            )
+        except ImportError as e:
+            logger.warn(
+                f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
+            )
+            cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
+    else:
+        cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
     system_app.register(
         LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
     )
diff --git a/pilot/cache/storage/base.py b/pilot/cache/storage/base.py
index 1e1fdde80..ea07bfacf 100644
--- a/pilot/cache/storage/base.py
+++ b/pilot/cache/storage/base.py
@@ -184,7 +184,7 @@ class CacheStorage(ABC):
 
 
 class MemoryCacheStorage(CacheStorage):
-    def __init__(self, max_memory_mb: int = 1024):
+    def __init__(self, max_memory_mb: int = 256):
         self.cache = OrderedDict()
         self.max_memory = max_memory_mb * 1024 * 1024
         self.current_memory_usage = 0
diff --git a/pilot/configs/config.py b/pilot/configs/config.py
index 9755c825d..f93cd7b83 100644
--- a/pilot/configs/config.py
+++ b/pilot/configs/config.py
@@ -253,9 +253,17 @@ class Config(metaclass=Singleton):
         ### Temporary configuration
         self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
 
-        self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
-        self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
-            "MODEL_CACHE_STORAGE_DIST_DIR"
+        self.MODEL_CACHE_ENABLE: bool = (
+            os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
+        )
+        self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
+            "MODEL_CACHE_STORAGE_TYPE", "disk"
+        )
+        self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
+            os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
+        )
+        self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
+            "MODEL_CACHE_STORAGE_DISK_DIR"
         )
 
     def set_debug_mode(self, value: bool) -> None:
diff --git a/pilot/server/component_configs.py b/pilot/server/component_configs.py
index 7247d14dc..58269385b 100644
--- a/pilot/server/component_configs.py
+++ b/pilot/server/component_configs.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
 import os
 
 from pilot.component import ComponentType, SystemApp
+from pilot.configs.config import Config
 from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
 from pilot.utils.executor_utils import DefaultExecutorFactory
 from pilot.embedding_engine.embedding_factory import EmbeddingFactory
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+CFG = Config()
+
 
 def initialize_components(
     param: WebWerverParameters,
@@ -41,10 +44,7 @@ def initialize_components(
     _initialize_embedding_model(
         param, system_app, embedding_model_name, embedding_model_path
     )
-
-    from pilot.cache import initialize_cache
-
-    initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
+    _initialize_model_cache(system_app)
 
 
 def _initialize_embedding_model(
@@ -136,3 +136,16 @@ class LocalEmbeddingFactory(EmbeddingFactory):
         loader = EmbeddingLoader()
         # Ignore model_name args
         return loader.load(self._default_model_name, model_params)
+
+
+def _initialize_model_cache(system_app: SystemApp):
+    from pilot.cache import initialize_cache
+
+    if not CFG.MODEL_CACHE_ENABLE:
+        logger.info("Model cache is not enabled")
+        return
+
+    storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk"
+    max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
+    persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
+    initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
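
For reference, the sketch below shows how the pieces of this patch fit together end to end: the env vars added in .env.template are read into Config, and _initialize_model_cache forwards them to the new initialize_cache signature. It is a minimal illustration only, not part of the patch; constructing SystemApp() with no arguments and the fallback directory "pilot/data/model_cache" (taken from the .env.template comment) are assumptions, everything else comes from the diff above. In the server itself this wiring runs inside _initialize_model_cache, guarded by MODEL_CACHE_ENABLE.

# Minimal usage sketch (not part of this patch). Assumptions are marked below.
import os

from pilot.component import SystemApp
from pilot.cache import initialize_cache

# Mirror the settings documented in .env.template.
storage_type = os.getenv("MODEL_CACHE_STORAGE_TYPE", "disk")
max_memory_mb = int(os.getenv("MODEL_CACHE_MAX_MEMORY_MB", "256"))
persist_dir = os.getenv("MODEL_CACHE_STORAGE_DISK_DIR") or "pilot/data/model_cache"  # assumed default dir

system_app = SystemApp()  # assumption: no-arg construction is enough for a demo
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)

# With storage_type="memory", MemoryCacheStorage(max_memory_mb=...) is used;
# with "disk", DiskCacheStorage is tried first and falls back to the memory
# backend if the import fails.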