[zero] get memory usage for sharded param (#536)

This commit is contained in:
Jiarui Fang
2022-03-28 15:01:21 +08:00
committed by GitHub
parent 56ad945797
commit 37cb70feec
2 changed files with 45 additions and 2 deletions

View File

@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
sparam.remove_torch_payload()
assert (param.data.numel() == 1)
# Test get memory usage
sparam.fp32_grad = torch.randn(2, 3)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = None
sparam.fp32_grad = torch.randn(2, 3)
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
if __name__ == '__main__':
test_shard_tensor(2)
# test_shard_tensor(2)
test_shard_param_v2(2)