[refactor] refactor the memory utils (#715)

This commit is contained in:
Jiarui Fang
2022-04-11 16:47:57 +08:00
committed by GitHub
parent dbd96fe90a
commit 193dc8dacb
20 changed files with 218 additions and 308 deletions

View File

@@ -20,13 +20,15 @@ def set_to_cuda(models):
return models.to(get_current_device())
def get_current_device():
"""Returns the index of a currently selected device (gpu/cpu).
def get_current_device() -> torch.device:
"""
Returns currently selected device (gpu/cpu).
If cuda available, return gpu, otherwise return cpu.
"""
if torch.cuda.is_available():
return torch.cuda.current_device()
return torch.device(f'cuda:{torch.cuda.current_device()}')
else:
return 'cpu'
return torch.device('cpu')
def synchronize():