chore: Add _clear_model_cache function to clear model cache

This commit is contained in:
FangYin Cheng
2023-09-21 18:19:19 +08:00
parent 461179ee6f
commit 28c24516a2
3 changed files with 21 additions and 11 deletions

View File

@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
del self.tokenizer
self.model = None
self.tokenizer = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
torch_imported = False

View File

@@ -11,7 +11,7 @@ from pilot.model.parameter import (
)
from pilot.model.cluster.worker_base import ModelWorker
from pilot.model.cluster.embedding.loader import EmbeddingLoader
from pilot.utils.model_utils import _clear_torch_cache
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
return
del self._embeddings_impl
self._embeddings_impl = None
_clear_torch_cache(self._model_params.device)
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict):
"""Generate stream result, chat scene"""

View File

@@ -1,11 +1,21 @@
import logging
logger = logging.getLogger(__name__)
def _clear_model_cache(device="cuda"):
    """Release cached model resources for the given device.

    Currently this only clears the torch device cache (via the sibling
    ``_clear_torch_cache`` helper); caches for other frameworks may be
    added later (see TODO).

    Args:
        device: Device identifier forwarded to the torch cache cleaner
            (e.g. ``"cuda"``, ``"cpu"``).
    """
    try:
        # Imported lazily so environments without torch still work.
        import torch  # noqa: F401

        _clear_torch_cache(device)
    except ImportError:
        # logger.warn() is deprecated since Python 3.3; use warning().
        logger.warning("Torch not installed, skip clear torch cache")
    # TODO clear other cache
def _clear_torch_cache(device="cuda"):
try:
import torch
except ImportError:
return
import torch
import gc
gc.collect()
@@ -16,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")
logger.info("No cuda or mps, not support clear torch cache yet")