mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[zero] improve the accuracy of get_memory_usage of sharded param (#538)
This commit is contained in:
@@ -60,8 +60,24 @@ class ShardedParamV2(object):
|
||||
elif t.device.type == 'cuda':
|
||||
cuda_mem_use += t.numel() * t.element_size()
|
||||
|
||||
address_set = set()
|
||||
_update_mem_use(self.sharded_data_tensor.payload)
|
||||
_update_mem_use(self.fp16_grad)
|
||||
_update_mem_use(self.fp32_grad)
|
||||
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||
|
||||
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp16_grad)
|
||||
address_set.add(self.fp16_grad.data_ptr())
|
||||
|
||||
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp32_grad)
|
||||
address_set.add(self.fp32_grad.data_ptr())
|
||||
|
||||
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.data)
|
||||
address_set.add(self.param.data.data_ptr())
|
||||
|
||||
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.grad)
|
||||
address_set.add(self.param.grad.data_ptr())
|
||||
|
||||
return cuda_mem_use, cpu_mem_use
|
||||
|
||||
Reference in New Issue
Block a user