diff --git a/pilot/model/cluster/worker/default_worker.py b/pilot/model/cluster/worker/default_worker.py
index 2078c55e7..e72e9ba31 100644
--- a/pilot/model/cluster/worker/default_worker.py
+++ b/pilot/model/cluster/worker/default_worker.py
@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
         del self.tokenizer
         self.model = None
         self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         torch_imported = False
diff --git a/pilot/model/cluster/worker/embedding_worker.py b/pilot/model/cluster/worker/embedding_worker.py
index a9f934a1c..62b799864 100644
--- a/pilot/model/cluster/worker/embedding_worker.py
+++ b/pilot/model/cluster/worker/embedding_worker.py
@@ -11,7 +11,7 @@ from pilot.model.parameter import (
 )
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 
 logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
             return
         del self._embeddings_impl
         self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)
 
     def generate_stream(self, params: Dict):
         """Generate stream result, chat scene"""
diff --git a/pilot/utils/model_utils.py b/pilot/utils/model_utils.py
index 037e3a021..d9527118e 100644
--- a/pilot/utils/model_utils.py
+++ b/pilot/utils/model_utils.py
@@ -1,11 +1,21 @@
 import logging
 
+logger = logging.getLogger(__name__)
+
+
+def _clear_model_cache(device="cuda"):
+    try:
+        # clear torch cache
+        import torch
+
+        _clear_torch_cache(device)
+    except ImportError:
+        logger.warn("Torch not installed, skip clear torch cache")
+    # TODO clear other cache
+
 
 def _clear_torch_cache(device="cuda"):
-    try:
-        import torch
-    except ImportError:
-        return
+    import torch
     import gc
 
     gc.collect()
@@ -16,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
 
                 empty_cache()
             except Exception as e:
-                logging.warn(f"Clear mps torch cache error, {str(e)}")
+                logger.warn(f"Clear mps torch cache error, {str(e)}")
         elif torch.has_cuda:
             device_count = torch.cuda.device_count()
             for device_id in range(device_count):
                 cuda_device = f"cuda:{device_id}"
-                logging.info(f"Clear torch cache of device: {cuda_device}")
+                logger.info(f"Clear torch cache of device: {cuda_device}")
                 with torch.cuda.device(cuda_device):
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
         else:
-            logging.info("No cuda or mps, not support clear torch cache yet")
+            logger.info("No cuda or mps, not support clear torch cache yet")
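
Note on usage (not part of the patch): after this change, worker teardown paths call the framework-agnostic _clear_model_cache instead of _clear_torch_cache, so cache clearing degrades to a logged no-op when torch is not installed rather than silently returning. A minimal sketch of the call pattern this establishes, assuming a hypothetical MyWorker class (the real callers are the two workers patched above):

    from pilot.utils.model_utils import _clear_model_cache


    class MyWorker:
        """Hypothetical worker used only to illustrate the teardown pattern."""

        def __init__(self, device: str = "cuda"):
            self.device = device
            self.model = None  # stands in for a loaded model object

        def stop(self) -> None:
            # Drop strong references first so the gc.collect() inside the
            # helper can actually release the memory, then clear the
            # device cache for whichever frameworks are available.
            del self.model
            self.model = None
            _clear_model_cache(self.device)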