[zero] Improve the accuracy of get_memory_usage for a sharded parameter (#538)

This commit is contained in:
Jiarui Fang
2022-03-28 16:19:19 +08:00
committed by GitHub
parent 37cb70feec
commit a590ed0ba3
2 changed files with 39 additions and 8 deletions

View File

@@ -60,8 +60,24 @@ class ShardedParamV2(object):
elif t.device.type == 'cuda':
cuda_mem_use += t.numel() * t.element_size()
address_set = set()
_update_mem_use(self.sharded_data_tensor.payload)
_update_mem_use(self.fp16_grad)
_update_mem_use(self.fp32_grad)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
_update_mem_use(self.fp16_grad)
address_set.add(self.fp16_grad.data_ptr())
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
_update_mem_use(self.fp32_grad)
address_set.add(self.fp32_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
_update_mem_use(self.param.data)
address_set.add(self.param.data.data_ptr())
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
_update_mem_use(self.param.grad)
address_set.add(self.param.grad.data_ptr())
return cuda_mem_use, cpu_mem_use