mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-15 14:47:16 +00:00
[zero] get memory usage for sharded param (#536)
This commit is contained in:
@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
|
||||
sparam.remove_torch_payload()
|
||||
assert (param.data.numel() == 1)
|
||||
|
||||
# Test get memory usage
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
|
||||
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 2 * 3 * 2
|
||||
|
||||
sparam.fp16_grad = None
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize("world_size", [1, 2])
|
||||
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_shard_tensor(2)
|
||||
# test_shard_tensor(2)
|
||||
test_shard_param_v2(2)
|
||||
|
||||
Reference in New Issue
Block a user