mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
Migrated project
This commit is contained in:
48
colossalai/utils/cuda.py
Normal file
48
colossalai/utils/cuda.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def set_to_cuda(models):
|
||||
'''Send model to gpu.
|
||||
|
||||
:param models: nn.module or a list of module
|
||||
'''
|
||||
if isinstance(models, list) and len(models) > 1:
|
||||
ret = []
|
||||
for model in models:
|
||||
ret.append(model.to(get_current_device()))
|
||||
return ret
|
||||
elif isinstance(models, list):
|
||||
return models[0].to(get_current_device())
|
||||
else:
|
||||
return models.to(get_current_device())
|
||||
|
||||
|
||||
def get_current_device():
|
||||
'''
|
||||
Returns the index of a currently selected device (gpu/cpu).
|
||||
'''
|
||||
if torch.cuda.is_available():
|
||||
return torch.cuda.current_device()
|
||||
else:
|
||||
return 'cpu'
|
||||
|
||||
|
||||
def synchronize():
|
||||
'''
|
||||
Similar to cuda.synchronize().
|
||||
Waits for all kernels in all streams on a CUDA device to complete.
|
||||
'''
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def empty_cache():
|
||||
'''
|
||||
Similar to cuda.empty_cache()
|
||||
Releases all unoccupied cached memory currently held by the caching allocator.
|
||||
'''
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
Reference in New Issue
Block a user