mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-19 00:14:40 +00:00
chore: Add _clear_model_cache function to clear model cache
This commit is contained in:
@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
|
||||
from pilot.model.parameter import ModelParameters
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
|
||||
del self.tokenizer
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
_clear_model_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
|
||||
torch_imported = False
|
||||
|
@@ -11,7 +11,7 @@ from pilot.model.parameter import (
|
||||
)
|
||||
from pilot.model.cluster.worker_base import ModelWorker
|
||||
from pilot.model.cluster.embedding.loader import EmbeddingLoader
|
||||
from pilot.utils.model_utils import _clear_torch_cache
|
||||
from pilot.utils.model_utils import _clear_model_cache
|
||||
from pilot.utils.parameter_utils import EnvArgumentParser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
|
||||
return
|
||||
del self._embeddings_impl
|
||||
self._embeddings_impl = None
|
||||
_clear_torch_cache(self._model_params.device)
|
||||
_clear_model_cache(self._model_params.device)
|
||||
|
||||
def generate_stream(self, params: Dict):
|
||||
"""Generate stream result, chat scene"""
|
||||
|
@@ -1,11 +1,21 @@
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _clear_model_cache(device="cuda"):
    """Release cached model resources associated with *device*.

    Currently this only clears the torch allocator caches by delegating to
    ``_clear_torch_cache``; if torch is not installed the call degrades to a
    logged warning instead of raising.

    Args:
        device: Device spec whose cache should be cleared (default ``"cuda"``).
    """
    try:
        # Probe for torch first so that a missing dependency is reported as
        # a warning rather than propagating an ImportError to the caller.
        import torch  # noqa: F401

        _clear_torch_cache(device)
    except ImportError:
        # Use logger.warning: logger.warn is a deprecated alias (Python 3.3+).
        logger.warning("Torch not installed, skip clear torch cache")
    # TODO clear other cache (e.g. tokenizer / embedding caches)
|
||||
|
||||
|
||||
def _clear_torch_cache(device="cuda"):
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return
|
||||
import torch
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
@@ -16,14 +26,14 @@ def _clear_torch_cache(device="cuda"):
|
||||
|
||||
empty_cache()
|
||||
except Exception as e:
|
||||
logging.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
logger.warn(f"Clear mps torch cache error, {str(e)}")
|
||||
elif torch.has_cuda:
|
||||
device_count = torch.cuda.device_count()
|
||||
for device_id in range(device_count):
|
||||
cuda_device = f"cuda:{device_id}"
|
||||
logging.info(f"Clear torch cache of device: {cuda_device}")
|
||||
logger.info(f"Clear torch cache of device: {cuda_device}")
|
||||
with torch.cuda.device(cuda_device):
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
else:
|
||||
logging.info("No cuda or mps, not support clear torch cache yet")
|
||||
logger.info("No cuda or mps, not support clear torch cache yet")
|
||||
|
Reference in New Issue
Block a user