Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-11-03 23:48:41 +00:00)
			
		
		
		
	[zero] get memory usage for sharded param (#536)
This commit is contained in:
@@ -54,6 +54,24 @@ def _run_shard_param_v2(rank, world_size, port):
     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)
 
+    # Test get memory usage
+    sparam.fp32_grad = torch.randn(2, 3)
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+
+    sparam.fp16_grad = torch.randn(2, 3).cuda().half()
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cuda_mem_use == 2 * 3 * 2
+
+    sparam.fp16_grad = None
+    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.remove_torch_payload()
+    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
+    assert cpu_mem_use == 2 * 3 * 4 * 2
+    assert cuda_mem_use == 0
+    print(f'cuda_mem_use {cuda_mem_use} cpu_mem_use {cpu_mem_use}')
+
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
@@ -64,5 +82,5 @@ def test_shard_param_v2(world_size):
 
 
 if __name__ == '__main__':
-    test_shard_tensor(2)
+    # test_shard_tensor(2)
     test_shard_param_v2(2)
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user