mirror of
				https://github.com/hpcaitech/ColossalAI.git
				synced 2025-10-25 01:40:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			46 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			46 lines
		
	
	
		
			1.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import torch
 | |
| 
 | |
| 
 | |
| def set_to_cuda(models):
 | |
|     """Send model to gpu.
 | |
| 
 | |
|     :param models: nn.module or a list of module
 | |
|     """
 | |
|     if isinstance(models, list) and len(models) > 1:
 | |
|         ret = []
 | |
|         for model in models:
 | |
|             ret.append(model.to(get_current_device()))
 | |
|         return ret
 | |
|     elif isinstance(models, list):
 | |
|         return models[0].to(get_current_device())
 | |
|     else:
 | |
|         return models.to(get_current_device())
 | |
| 
 | |
| 
 | |
| def get_current_device():
 | |
|     """Returns the index of a currently selected device (gpu/cpu).
 | |
|     """
 | |
|     if torch.cuda.is_available():
 | |
|         return torch.cuda.current_device()
 | |
|     else:
 | |
|         return 'cpu'
 | |
| 
 | |
| 
 | |
| def synchronize():
 | |
|     """Similar to cuda.synchronize().
 | |
|     Waits for all kernels in all streams on a CUDA device to complete.
 | |
|     """
 | |
|     if torch.cuda.is_available():
 | |
|         torch.cuda.synchronize()
 | |
| 
 | |
| 
 | |
| def empty_cache():
 | |
|     """Similar to cuda.empty_cache()
 | |
|     Releases all unoccupied cached memory currently held by the caching allocator.
 | |
|     """
 | |
|     if torch.cuda.is_available():
 | |
|         torch.cuda.empty_cache()
 |