diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 11f738615..9ffdce311 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -8,12 +8,28 @@ from torch.testing import assert_close import colossalai from colossalai.accelerator import get_accelerator -from colossalai.testing import spawn +from colossalai.testing import parameterize, spawn from colossalai.testing.random import seed_all from colossalai.utils import conditional_context from colossalai.zero import LowLevelZeroOptimizer +def loose_close(a, b, dtype: torch.dtype = torch.float32): + rtol = None + atol = None + if dtype is torch.float16: + rtol = 5e-2 + atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 + + a = a.detach().to(dtype) + b = b.detach().to(dtype) + + assert_close(a, b, rtol=rtol, atol=atol) + + class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() @@ -26,7 +42,9 @@ class MlpModel(nn.Module): return x -def exam_zero_1_2_grad_acc(): +@parameterize("sub_dp_size", [1, 2]) +def exam_zero_1_2_grad_acc(sub_dp_size: int): + assert torch.distributed.get_world_size() % sub_dp_size == 0 local_rank = torch.distributed.get_rank() seed_all(2009) device = get_accelerator().get_current_device() @@ -37,10 +55,20 @@ def exam_zero_1_2_grad_acc(): zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) zero1_optimizer = LowLevelZeroOptimizer( - zero1_optimizer, overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, verbose=True + zero1_optimizer, + overlap_communication=True, + initial_scale=32, + clip_grad_norm=1.0, + verbose=True, + sub_dp_size=sub_dp_size, ) zero2_optimizer = LowLevelZeroOptimizer( - zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, clip_grad_norm=1.0 + zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=32, + clip_grad_norm=1.0, + sub_dp_size=sub_dp_size, ) # create data seed_all(2021 + local_rank) @@ -51,7 +79,7 @@ def exam_zero_1_2_grad_acc(): # zero-dp forward zero1_output = zero1_model(cur_data) zero2_output = zero2_model(cur_data) - assert torch.equal(zero1_output, zero2_output) + loose_close(zero1_output, zero2_output) # zero-dp backward zero1_optimizer.backward(zero1_output.sum().float()) @@ -66,10 +94,13 @@ def exam_zero_1_2_grad_acc(): # check updated param for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - assert torch.equal(z1p.data, z2p.data) + loose_close(z1p.data, z2p.data) -def exam_zero_1_grad_acc(sync): +@parameterize("no_sync", [True, False]) +@parameterize("sub_dp_size", [1, 2]) +def exam_zero_1_grad_acc(no_sync: bool, sub_dp_size: int): + assert torch.distributed.get_world_size() % sub_dp_size == 0 local_rank = torch.distributed.get_rank() seed_all(2008) device = get_accelerator().get_current_device() @@ -89,7 +120,11 @@ def exam_zero_1_grad_acc(sync): # in `check_sharded_param_consistency.py`, we will test whether # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer( - zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, clip_grad_norm=1.0 + zero_optimizer, + overlap_communication=False, + reduce_bucket_size=262144, + clip_grad_norm=1.0, + sub_dp_size=sub_dp_size, ) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -108,38 +143,37 @@ def exam_zero_1_grad_acc(sync): # torch-ddp fwd and bwd with conditional_context(torch_model.no_sync(), no_sync): torch_output = torch_model(cur_data) - assert torch.equal(zero_output, torch_output) + loose_close(zero_output, torch_output) torch_output.sum().backward() if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - assert torch.equal(p.grad, z1p.grad) + loose_close(p.grad, z1p.grad) - fwd_bwd_func(sync, input_data1, sync) + fwd_bwd_func(no_sync, input_data1, no_sync) fwd_bwd_func(False, input_data2, False) + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) zero_optimizer.step() - torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) torch_optimizer.step() # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data))) - assert_close(p.data, z1p.data) + loose_close(p.data, z1p.data) def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost") - exam_zero_1_grad_acc(sync=True) - exam_zero_1_grad_acc(sync=False) + exam_zero_1_grad_acc() exam_zero_1_2_grad_acc() @pytest.mark.dist def test_grad_accumulation(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__": diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index e2196cfbf..855a4efb3 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -51,7 +51,8 @@ def split_ddp_grad(grad, world_size): return splited_grad -def exam_zero_1_2(): +@parameterize("sub_dp_size", [1, 2]) +def exam_zero_1_2(sub_dp_size: int): """ In this test, we want to test whether zero stage 1 and 2 deliver the same numerical results despite different communication @@ -62,6 +63,7 @@ def exam_zero_1_2(): pg: partition gradients and optimizer states """ + assert torch.distributed.get_world_size() % sub_dp_size == 0 local_rank = torch.distributed.get_rank() seed_all(2001) @@ -73,10 +75,10 @@ def exam_zero_1_2(): zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) zero1_optimizer = LowLevelZeroOptimizer( - zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True + zero1_optimizer, overlap_communication=True, initial_scale=128, verbose=True, sub_dp_size=sub_dp_size ) zero2_optimizer = LowLevelZeroOptimizer( - zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128 + zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=128, sub_dp_size=sub_dp_size ) # create data seed_all(2001 + local_rank) @@ -94,7 +96,7 @@ def exam_zero_1_2(): z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0) z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0) for z1g, z2g in zip(z1g_list, z2g_list): - assert torch.equal(z1g, z2g) + loose_close(z1g, z2g) # step zero1_optimizer.step() @@ -102,12 +104,13 @@ def exam_zero_1_2(): # check updated param for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - assert torch.equal(z1p.data, z2p.data) + loose_close(z1p.data, z2p.data) @parameterize("dtype", [torch.float16, torch.bfloat16]) @parameterize("master_weights", [True, False]) -def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): +@parameterize("sub_dp_size", [1, 2]) +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool, sub_dp_size: int): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -116,6 +119,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): We feed these two sets of models with the same input and check if the differences in model output and updated parameters are within tolerance. """ + assert world_size % sub_dp_size == 0 local_rank = torch.distributed.get_rank() seed_all(1453) @@ -137,6 +141,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): initial_scale=1, reduce_bucket_size=1024 * 1024, master_weights=master_weights, + sub_dp_size=sub_dp_size, ) torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) @@ -162,7 +167,7 @@ def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype, master_weights: bool): for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): if p.grad is not None: zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p)) - torch_grad_list = split_ddp_grad(p.grad, world_size) + torch_grad_list = split_ddp_grad(p.grad, world_size // sub_dp_size) for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): loose_close(zero_grad, torch_grad, dtype=dtype) @@ -187,7 +192,7 @@ def run_dist(rank, world_size, port): @pytest.mark.dist @rerun_if_address_is_in_use() def test_zero_1_2(): - spawn(run_dist, 2) + spawn(run_dist, 4) if __name__ == "__main__":