Mirror of https://github.com/csunny/DB-GPT.git
chore: Add _clear_model_cache function to clear model cache
@@ -8,7 +8,7 @@ from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser

 logger = logging.getLogger(__name__)
@@ -87,7 +87,7 @@ class DefaultModelWorker(ModelWorker):
         del self.tokenizer
         self.model = None
         self.tokenizer = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

     def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
         torch_imported = False
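
Note: the teardown pattern around this call is: drop the strong references to the model objects, then clear the device cache. A minimal sketch of that pattern, assuming the pilot package is importable; DemoWorker is a hypothetical stand-in for DefaultModelWorker:

    from pilot.utils.model_utils import _clear_model_cache

    class DemoWorker:
        def __init__(self, model, tokenizer, device="cuda"):
            self.model = model
            self.tokenizer = tokenizer
            self.device = device

        def stop(self):
            # Drop the strong references first so the gc.collect() inside
            # _clear_torch_cache can actually free the underlying tensors.
            del self.model
            del self.tokenizer
            self.model = None
            self.tokenizer = None
            _clear_model_cache(self.device)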

@@ -11,7 +11,7 @@ from pilot.model.parameter import (
 )
 from pilot.model.cluster.worker_base import ModelWorker
 from pilot.model.cluster.embedding.loader import EmbeddingLoader
-from pilot.utils.model_utils import _clear_torch_cache
+from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser

 logger = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ class EmbeddingsModelWorker(ModelWorker):
             return
         del self._embeddings_impl
         self._embeddings_impl = None
-        _clear_torch_cache(self._model_params.device)
+        _clear_model_cache(self._model_params.device)

     def generate_stream(self, params: Dict):
         """Generate stream result, chat scene"""

@@ -1,11 +1,21 @@
 import logging

+logger = logging.getLogger(__name__)
+
+
+def _clear_model_cache(device="cuda"):
+    try:
+        # clear torch cache
+        import torch
+
+        _clear_torch_cache(device)
+    except ImportError:
+        logger.warn("Torch not installed, skip clear torch cache")
+    # TODO clear other cache
+
+
 def _clear_torch_cache(device="cuda"):
     try:
         import torch
     except ImportError:
         return
     import gc

     gc.collect()
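
Note: _clear_model_cache is deliberately a thin wrapper. torch is imported inside the try block, so importing pilot.utils.model_utils stays cheap and works in torch-free deployments; on ImportError the function only logs a warning, and the TODO leaves room for clearing non-torch caches later. A minimal usage sketch (assuming the pilot package is importable):

    from pilot.utils.model_utils import _clear_model_cache

    # Safe to call whether or not torch is installed; without torch it only
    # logs "Torch not installed, skip clear torch cache".
    _clear_model_cache()              # defaults to device="cuda"
    _clear_model_cache(device="cpu")  # device is forwarded to _clear_torch_cache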
@@ -16,14 +26,14 @@ def _clear_torch_cache(device="cuda"):

                 empty_cache()
             except Exception as e:
-                logging.warn(f"Clear mps torch cache error, {str(e)}")
+                logger.warn(f"Clear mps torch cache error, {str(e)}")
         elif torch.has_cuda:
             device_count = torch.cuda.device_count()
             for device_id in range(device_count):
                 cuda_device = f"cuda:{device_id}"
-                logging.info(f"Clear torch cache of device: {cuda_device}")
+                logger.info(f"Clear torch cache of device: {cuda_device}")
                 with torch.cuda.device(cuda_device):
                     torch.cuda.empty_cache()
                     torch.cuda.ipc_collect()
         else:
-            logging.info("No cuda or mps, not support clear torch cache yet")
+            logger.info("No cuda or mps, not support clear torch cache yet")
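
Note: besides switching from the root logging functions to the module-level logger (so records are attributed to this module), the CUDA branch clears every visible GPU, activating each one with the torch.cuda.device context manager before emptying its cache. A standalone sketch of that per-device loop against the public torch.cuda API; clear_all_cuda_caches is an illustrative helper, not part of this commit:

    import logging

    import torch

    logger = logging.getLogger(__name__)

    def clear_all_cuda_caches():
        if not torch.cuda.is_available():
            logger.info("CUDA not available, nothing to clear")
            return
        for device_id in range(torch.cuda.device_count()):
            with torch.cuda.device(f"cuda:{device_id}"):
                torch.cuda.empty_cache()  # release cached allocator blocks to the driver
                torch.cuda.ipc_collect()  # reclaim memory from expired CUDA IPC handles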