mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-09-29 21:40:44 +00:00
28 lines
853 B
Python
28 lines
853 B
Python
import logging
|
|
|
|
|
|
def _clear_torch_cache(device="cuda"):
|
|
import gc
|
|
|
|
import torch
|
|
|
|
gc.collect()
|
|
if device != "cpu":
|
|
if torch.has_mps:
|
|
try:
|
|
from torch.mps import empty_cache
|
|
|
|
empty_cache()
|
|
except Exception as e:
|
|
logging.warn(f"Clear mps torch cache error, {str(e)}")
|
|
elif torch.has_cuda:
|
|
device_count = torch.cuda.device_count()
|
|
for device_id in range(device_count):
|
|
cuda_device = f"cuda:{device_id}"
|
|
logging.info(f"Clear torch cache of device: {cuda_device}")
|
|
with torch.cuda.device(cuda_device):
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.ipc_collect()
|
|
else:
|
|
logging.info("No cuda or mps, not support clear torch cache yet")
|