[zero] get memory usage for sharded param (#536)

Author: Jiarui Fang
Date: 2022-03-28 15:01:21 +08:00
Committed by: GitHub
Parent: 56ad945797
Commit: 37cb70feec
2 changed files with 45 additions and 2 deletions


@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
-from typing import Optional
+from typing import Optional, Tuple


 class ShardedParamV2(object):
@@ -40,3 +40,28 @@ class ShardedParamV2(object):
     @property
     def param_is_sharded(self):
         return self._sharded_data_tensor.is_sharded
+
+    def get_memory_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the param, including data and grad.
+        Returns:
+            Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes.
+        """
+        cuda_mem_use, cpu_mem_use = 0, 0
+
+        def _update_mem_use(t: Optional[torch.Tensor]):
+            if t is None:
+                return
+            assert isinstance(t, torch.Tensor)
+            nonlocal cuda_mem_use
+            nonlocal cpu_mem_use
+            if t.device.type == 'cpu':
+                cpu_mem_use += t.numel() * t.element_size()
+            elif t.device.type == 'cuda':
+                cuda_mem_use += t.numel() * t.element_size()
+
+        _update_mem_use(self.sharded_data_tensor.payload)
+        _update_mem_use(self.fp16_grad)
+        _update_mem_use(self.fp32_grad)
+
+        return cuda_mem_use, cpu_mem_use
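
As a sanity check on the accounting above, here is a minimal standalone sketch of the same per-device bookkeeping (numel() * element_size(), bucketed by t.device.type). The helper name tensor_mem_usage is hypothetical and written for illustration only; it is not part of this commit:

import torch
from typing import Iterable, Optional, Tuple

def tensor_mem_usage(tensors: Iterable[Optional[torch.Tensor]]) -> Tuple[int, int]:
    # Mirrors the accounting in get_memory_usage(): each tensor contributes
    # numel() * element_size() bytes to the bucket of the device it lives on.
    cuda_mem_use, cpu_mem_use = 0, 0
    for t in tensors:
        if t is None:  # absent tensors (e.g. a grad before backward) count as 0
            continue
        if t.device.type == 'cpu':
            cpu_mem_use += t.numel() * t.element_size()
        elif t.device.type == 'cuda':
            cuda_mem_use += t.numel() * t.element_size()
    return cuda_mem_use, cpu_mem_use

# A 1024 x 1024 fp16 tensor on CPU occupies 1024 * 1024 * 2 = 2 MiB.
data = torch.zeros(1024, 1024, dtype=torch.float16)
print(tensor_mem_usage([data, None]))  # -> (0, 2097152)

Note that numel() * element_size() measures the logical element count rather than the underlying storage, so the reported figure can deviate from the actual allocation for non-contiguous views or tensors that share storage.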