Files
DB-GPT/pilot/utils/model_utils.py
2023-08-29 22:54:18 +08:00

28 lines
853 B
Python

import logging
def _clear_torch_cache(device="cuda"):
import gc
import torch
gc.collect()
if device != "cpu":
if torch.has_mps:
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
logging.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logging.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logging.info("No cuda or mps, not support clear torch cache yet")