mirror of
				https://github.com/csunny/DB-GPT.git
				synced 2025-10-26 20:29:34 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			86 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			86 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import logging
 | |
| from dataclasses import dataclass
 | |
| from typing import List, Tuple
 | |
| 
 | |
| logger = logging.getLogger(__name__)
 | |
| 
 | |
| 
 | |
def _clear_model_cache(device="cuda"):
    """Release cached model memory held by the supported backends.

    Only the torch cache is handled at present. When torch cannot be
    imported the step is skipped and a warning is logged.

    Args:
        device: Device name forwarded to ``_clear_torch_cache``.
    """
    try:
        # Imported only to probe for torch; raises ImportError when absent.
        import torch  # noqa: F401

        _clear_torch_cache(device)
    except ImportError:
        # ``Logger.warn`` is a deprecated alias; ``warning`` is the supported name.
        logger.warning("Torch not installed, skip clear torch cache")
    # TODO clear other cache
 | |
| 
 | |
| 
 | |
| def _clear_torch_cache(device="cuda"):
 | |
|     import gc
 | |
| 
 | |
|     import torch
 | |
| 
 | |
|     gc.collect()
 | |
|     if device != "cpu":
 | |
|         if torch.has_mps:
 | |
|             try:
 | |
|                 from torch.mps import empty_cache
 | |
| 
 | |
|                 empty_cache()
 | |
|             except Exception as e:
 | |
|                 logger.warn(f"Clear mps torch cache error, {str(e)}")
 | |
|         elif torch.has_cuda:
 | |
|             device_count = torch.cuda.device_count()
 | |
|             for device_id in range(device_count):
 | |
|                 cuda_device = f"cuda:{device_id}"
 | |
|                 logger.info(f"Clear torch cache of device: {cuda_device}")
 | |
|                 with torch.cuda.device(cuda_device):
 | |
|                     torch.cuda.empty_cache()
 | |
|                     torch.cuda.ipc_collect()
 | |
|         else:
 | |
|             logger.info("No cuda or mps, not support clear torch cache yet")
 | |
| 
 | |
| 
 | |
@dataclass
class GPUInfo:
    """Memory figures for a single GPU, each expressed in GB."""

    # Total memory of the device.
    total_memory_gb: float
    # Memory currently allocated.
    allocated_memory_gb: float
    # Memory held in the cache.
    cached_memory_gb: float
    # Memory still available.
    available_memory_gb: float
 | |
| 
 | |
| 
 | |
def _get_current_cuda_memory() -> List[GPUInfo]:
    """Collect memory statistics for every visible CUDA device.

    Returns:
        One ``GPUInfo`` per CUDA device, in device-index order, with all
        values converted from bytes to units of 1024**3 bytes and rounded to
        two decimals. An empty list is returned (and a warning logged) when
        torch is not installed or CUDA is not available.
    """
    try:
        import torch
    except ImportError:
        logger.warning("Torch not installed")
        return []
    if not torch.cuda.is_available():
        logger.warning("CUDA is not available.")
        return []

    bytes_per_gb = 1024**3
    gpu_infos = []
    for gpu_id in range(torch.cuda.device_count()):
        # Query each device by explicit index instead of switching the
        # current device.
        gpu_properties = torch.cuda.get_device_properties(gpu_id)
        total_memory = round(gpu_properties.total_memory / bytes_per_gb, 2)
        allocated_memory = round(
            torch.cuda.memory_allocated(gpu_id) / bytes_per_gb, 2
        )
        cached_memory = round(torch.cuda.memory_reserved(gpu_id) / bytes_per_gb, 2)
        # Round again: subtracting two rounded floats can leave binary
        # floating-point noise beyond the second decimal.
        available_memory = round(total_memory - allocated_memory, 2)
        gpu_infos.append(
            GPUInfo(
                total_memory_gb=total_memory,
                allocated_memory_gb=allocated_memory,
                cached_memory_gb=cached_memory,
                available_memory_gb=available_memory,
            )
        )
    return gpu_infos
 |