refactor: The first refactored version for sdk release (#907)

Co-authored-by: chengfangyin2 <chengfangyin3@jd.com>
This commit is contained in:
FangYin Cheng
2023-12-08 14:45:59 +08:00
committed by GitHub
parent e7e4aff667
commit cd725db1fb
573 changed files with 2094 additions and 3571 deletions

84
dbgpt/util/model_utils.py Normal file
View File

@@ -0,0 +1,84 @@
from typing import List, Tuple
from dataclasses import dataclass
import logging
logger = logging.getLogger(__name__)
def _clear_model_cache(device="cuda"):
try:
# clear torch cache
import torch
_clear_torch_cache(device)
except ImportError:
logger.warn("Torch not installed, skip clear torch cache")
# TODO clear other cache
def _clear_torch_cache(device="cuda"):
import torch
import gc
gc.collect()
if device != "cpu":
if torch.has_mps:
try:
from torch.mps import empty_cache
empty_cache()
except Exception as e:
logger.warn(f"Clear mps torch cache error, {str(e)}")
elif torch.has_cuda:
device_count = torch.cuda.device_count()
for device_id in range(device_count):
cuda_device = f"cuda:{device_id}"
logger.info(f"Clear torch cache of device: {cuda_device}")
with torch.cuda.device(cuda_device):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
else:
logger.info("No cuda or mps, not support clear torch cache yet")
@dataclass
class GPUInfo:
total_memory_gb: float
allocated_memory_gb: float
cached_memory_gb: float
available_memory_gb: float
def _get_current_cuda_memory() -> List[GPUInfo]:
try:
import torch
except ImportError:
logger.warn("Torch not installed")
return []
if torch.cuda.is_available():
num_gpus = torch.cuda.device_count()
gpu_infos = []
for gpu_id in range(num_gpus):
with torch.cuda.device(gpu_id):
device = torch.cuda.current_device()
gpu_properties = torch.cuda.get_device_properties(device)
total_memory = round(gpu_properties.total_memory / (1.0 * 1024**3), 2)
allocated_memory = round(
torch.cuda.memory_allocated() / (1.0 * 1024**3), 2
)
cached_memory = round(
torch.cuda.memory_reserved() / (1.0 * 1024**3), 2
)
available_memory = total_memory - allocated_memory
gpu_infos.append(
GPUInfo(
total_memory_gb=total_memory,
allocated_memory_gb=allocated_memory,
cached_memory_gb=cached_memory,
available_memory_gb=available_memory,
)
)
return gpu_infos
else:
logger.warn("CUDA is not available.")
return []