[zero] get memory usage for sharded param (#536)
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
-from typing import Optional
+from typing import Optional, Tuple
 
 
 class ShardedParamV2(object):
@@ -40,3 +40,28 @@ class ShardedParamV2(object):
     @property
     def param_is_sharded(self):
         return self._sharded_data_tensor.is_sharded
+
+    def get_memory_usage(self) -> Tuple[int, int]:
+        """
+        Get the memory usage of the param, including data and grad.
+        Returns:
+            Tuple[int, int]: CUDA memory usage in bytes, CPU memory usage in bytes.
+        """
+        cuda_mem_use, cpu_mem_use = 0, 0
+
+        def _update_mem_use(t: Optional[torch.Tensor]):
+            if t is None:
+                return
+            assert isinstance(t, torch.Tensor)
+            nonlocal cuda_mem_use
+            nonlocal cpu_mem_use
+            if t.device.type == 'cpu':
+                cpu_mem_use += t.numel() * t.element_size()
+            elif t.device.type == 'cuda':
+                cuda_mem_use += t.numel() * t.element_size()
+
+        _update_mem_use(self.sharded_data_tensor.payload)
+        _update_mem_use(self.fp16_grad)
+        _update_mem_use(self.fp32_grad)
+
+        return cuda_mem_use, cpu_mem_use
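
Since each ShardedParamV2 now reports its own footprint, callers can sum the per-param numbers to estimate a whole model's memory use. The sketch below is not part of this commit: model_mem_usage is a hypothetical helper, and col_attr stands in for whatever attribute the training wrapper uses to attach a ShardedParamV2 to each torch parameter.

import torch

def model_mem_usage(model: torch.nn.Module):
    # Sum get_memory_usage() over every parameter that carries a ShardedParamV2.
    cuda_total, cpu_total = 0, 0
    for param in model.parameters():
        sharded = getattr(param, 'col_attr', None)  # hypothetical attachment point
        if sharded is None:
            continue  # plain parameter, not managed by ZeRO
        cuda_use, cpu_use = sharded.get_memory_usage()
        cuda_total += cuda_use
        cpu_total += cpu_use
    return cuda_total, cpu_total

The byte counts come from t.numel() * t.element_size(): a sharded fp16 payload of 1024 elements reports 1024 * 2 = 2048 bytes, since element_size() is 2 for torch.float16. Writing the accumulator as a closure with nonlocal lets one helper cover the data payload and both grad tensors while silently skipping grads that are None.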