[zero] trace states of fp16/32 grad and fp32 param (#571)

This commit is contained in:
ver217
2022-03-31 16:26:54 +08:00
committed by GitHub
parent 7675366fce
commit 7c6c427db1
5 changed files with 69 additions and 73 deletions

View File

@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
# 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = StatefulTensor(None)
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()