mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-01 09:06:55 +00:00
feat(cache): Support configures the model cache in .env
This commit is contained in:
@@ -55,6 +55,17 @@ QUANTIZE_8bit=True
|
||||
## Model path
|
||||
# llama_cpp_model_path=/data/models/TheBloke/vicuna-13B-v1.5-GGUF/vicuna-13b-v1.5.Q4_K_M.gguf
|
||||
|
||||
### LLM cache
|
||||
## Enable Model cache
|
||||
# MODEL_CACHE_ENABLE=True
|
||||
## The storage type of model cache, now supports: memory, disk
|
||||
# MODEL_CACHE_STORAGE_TYPE=disk
|
||||
## The max cache data to keep in memory; we always store cache data in memory first for high speed.
|
||||
# MODEL_CACHE_MAX_MEMORY_MB=256
|
||||
## The dir to save cache data. This configuration is only valid when MODEL_CACHE_STORAGE_TYPE=disk
|
||||
## The default dir is pilot/data/model_cache
|
||||
# MODEL_CACHE_STORAGE_DISK_DIR=
|
||||
|
||||
#*******************************************************************#
|
||||
#** EMBEDDING SETTINGS **#
|
||||
#*******************************************************************#
|
||||
|
26
pilot/cache/manager.py
vendored
26
pilot/cache/manager.py
vendored
@@ -100,19 +100,27 @@ class LocalCacheManager(CacheManager):
|
||||
return self._serializer
|
||||
|
||||
|
||||
def initialize_cache(system_app: SystemApp, persist_dir: str):
|
||||
def initialize_cache(
|
||||
system_app: SystemApp, storage_type: str, max_memory_mb: int, persist_dir: str
|
||||
):
|
||||
from pilot.cache.protocal.json_protocal import JsonSerializer
|
||||
from pilot.cache.storage.base import MemoryCacheStorage
|
||||
|
||||
try:
|
||||
from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
|
||||
cache_storage = None
|
||||
if storage_type == "disk":
|
||||
try:
|
||||
from pilot.cache.storage.disk.disk_storage import DiskCacheStorage
|
||||
|
||||
cache_storage = DiskCacheStorage(persist_dir)
|
||||
except ImportError as e:
|
||||
logger.warn(
|
||||
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
|
||||
)
|
||||
cache_storage = MemoryCacheStorage()
|
||||
cache_storage = DiskCacheStorage(
|
||||
persist_dir, mem_table_buffer_mb=max_memory_mb
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.warn(
|
||||
f"Can't import DiskCacheStorage, use MemoryCacheStorage, import error message: {str(e)}"
|
||||
)
|
||||
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
|
||||
else:
|
||||
cache_storage = MemoryCacheStorage(max_memory_mb=max_memory_mb)
|
||||
system_app.register(
|
||||
LocalCacheManager, serializer=JsonSerializer(), storage=cache_storage
|
||||
)
|
||||
|
2
pilot/cache/storage/base.py
vendored
2
pilot/cache/storage/base.py
vendored
@@ -184,7 +184,7 @@ class CacheStorage(ABC):
|
||||
|
||||
|
||||
class MemoryCacheStorage(CacheStorage):
|
||||
def __init__(self, max_memory_mb: int = 1024):
|
||||
def __init__(self, max_memory_mb: int = 256):
|
||||
self.cache = OrderedDict()
|
||||
self.max_memory = max_memory_mb * 1024 * 1024
|
||||
self.current_memory_usage = 0
|
||||
|
@@ -253,9 +253,17 @@ class Config(metaclass=Singleton):
|
||||
### Temporary configuration
|
||||
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
|
||||
|
||||
self.MODEL_CACHE_STORAGE: str = os.getenv("MODEL_CACHE_STORAGE")
|
||||
self.MODEL_CACHE_STORAGE_DIST_DIR: str = os.getenv(
|
||||
"MODEL_CACHE_STORAGE_DIST_DIR"
|
||||
self.MODEL_CACHE_ENABLE: bool = (
|
||||
os.getenv("MODEL_CACHE_ENABLE", "True").lower() == "true"
|
||||
)
|
||||
self.MODEL_CACHE_STORAGE_TYPE: str = os.getenv(
|
||||
"MODEL_CACHE_STORAGE_TYPE", "disk"
|
||||
)
|
||||
self.MODEL_CACHE_MAX_MEMORY_MB: int = int(
|
||||
os.getenv("MODEL_CACHE_MAX_MEMORY_MB", 256)
|
||||
)
|
||||
self.MODEL_CACHE_STORAGE_DISK_DIR: str = os.getenv(
|
||||
"MODEL_CACHE_STORAGE_DISK_DIR"
|
||||
)
|
||||
|
||||
def set_debug_mode(self, value: bool) -> None:
|
||||
|
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Type
|
||||
import os
|
||||
|
||||
from pilot.component import ComponentType, SystemApp
|
||||
from pilot.configs.config import Config
|
||||
from pilot.configs.model_config import MODEL_DISK_CACHE_DIR
|
||||
from pilot.utils.executor_utils import DefaultExecutorFactory
|
||||
from pilot.embedding_engine.embedding_factory import EmbeddingFactory
|
||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CFG = Config()
|
||||
|
||||
|
||||
def initialize_components(
|
||||
param: WebWerverParameters,
|
||||
@@ -41,10 +44,7 @@ def initialize_components(
|
||||
_initialize_embedding_model(
|
||||
param, system_app, embedding_model_name, embedding_model_path
|
||||
)
|
||||
|
||||
from pilot.cache import initialize_cache
|
||||
|
||||
initialize_cache(system_app, MODEL_DISK_CACHE_DIR)
|
||||
_initialize_model_cache(system_app)
|
||||
|
||||
|
||||
def _initialize_embedding_model(
|
||||
@@ -136,3 +136,16 @@ class LocalEmbeddingFactory(EmbeddingFactory):
|
||||
loader = EmbeddingLoader()
|
||||
# Ignore model_name args
|
||||
return loader.load(self._default_model_name, model_params)
|
||||
|
||||
|
||||
def _initialize_model_cache(system_app: SystemApp):
|
||||
from pilot.cache import initialize_cache
|
||||
|
||||
if not CFG.MODEL_CACHE_ENABLE:
|
||||
logger.info("Model cache is not enable")
|
||||
return
|
||||
|
||||
storage_type = CFG.MODEL_CACHE_STORAGE_TYPE or "disk"
|
||||
max_memory_mb = CFG.MODEL_CACHE_MAX_MEMORY_MB or 256
|
||||
persist_dir = CFG.MODEL_CACHE_STORAGE_DISK_DIR or MODEL_DISK_CACHE_DIR
|
||||
initialize_cache(system_app, storage_type, max_memory_mb, persist_dir)
|
||||
|
Reference in New Issue
Block a user