diff --git a/colossalai/testing/random.py b/colossalai/testing/random.py new file mode 100644 index 000000000..ad6d24a4b --- /dev/null +++ b/colossalai/testing/random.py @@ -0,0 +1,19 @@ +import random + +import numpy as np +import torch + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/sharded_optim/_utils.py index 49cf21969..9a839a570 100644 --- a/colossalai/zero/sharded_optim/_utils.py +++ b/colossalai/zero/sharded_optim/_utils.py @@ -1,11 +1,13 @@ import math + import torch +import torch.distributed as dist from torch._six import inf from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from colossalai.core import global_context as gpc + from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.utils import is_model_parallel_parameter -import torch.distributed as dist def flatten(input_): @@ -99,19 +101,24 @@ def split_half_float_double(tensor_list): return buckets -def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA): +def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA): """ Reduce the tensor in the data parallel process group :param tensor: A tensor object to reduce/all-reduce :param dtype: The data type used in communication :param dst_rank: The source rank for reduce. If dst_rank is None, + :param parallel_mode: Communication parallel mode all-reduce will be used instead of reduce. Default is None. :type tensor: torch.Tensor - :type dtype: torch.dtype + :type dtype: torch.dtype, optional :type dst_rank: int, optional + :type parallel_mode: ParallelMode, optional """ + # use the original dtype + if dtype is None: + dtype = tensor.dtype # cast the data to specified dtype for reduce/all-reduce if tensor.dtype != dtype: @@ -139,6 +146,7 @@ def reduce_tensor(tensor, dtype, dst_rank=None, parallel_mode=ParallelMode.DATA) local_rank = gpc.get_local_rank(parallel_mode) if use_all_reduce or dst_rank == local_rank: tensor.copy_(tensor_to_reduce) + return tensor @@ -238,7 +246,7 @@ def sync_param(flat_tensor, tensor_list): Synchronize the flattened tensor and unflattened tensor list. When a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, a new tensor is created. Thus, the flat tensor and original tensor list do not - share the same memory space. This function will update the tensor list so that + share the same memory space. This function will update the tensor list so that they point to the same value. :param flat_tensor: A flat tensor obtained by calling `torch._utils._unflatten_dense_tensors` on a tensor lsit diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py index 86e39077d..d30b69e7e 100644 --- a/colossalai/zero/sharded_optim/low_level_optim.py +++ b/colossalai/zero/sharded_optim/low_level_optim.py @@ -44,12 +44,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): max_scale: int = 2**32, # grad clipping - clip_grad_norm=2.0, + clip_grad_norm=0.0, verbose=False, # communication - reduce_bucket_size=50000000, - communication_dtype=torch.float16, + reduce_bucket_size=1024 * 1024, + communication_dtype=None, overlap_communication=False, # stage 2 @@ -58,7 +58,10 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): mp_parallel_mode=ParallelMode.MODEL, # cpu offload - cpu_offload=False): + cpu_offload=False, + + # forced dtype + forced_dtype=None): # TODO: add support for # 1. fp16 master weights @@ -112,6 +115,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # gradient clipping self._clip_grad_norm = clip_grad_norm + if forced_dtype: + for group in self._optimizer.param_groups: + group_params = group['params'] + for param in group_params: + param.data = param.data.to(forced_dtype) + self._dtype = forced_dtype + # check argument conflict self._sanity_checks() @@ -225,17 +235,21 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): fp32_partition_grad = torch.zeros_like(fp32_partition_param) fp32_partition_param.grad = fp32_partition_grad + # we do not need log information for optimizer, so comment them # update the parameter with zero gradients for initialization of optimizer states - self._optimizer.step() + # self._optimizer.step() # remove the grad of the paramter to save memory - for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items(): - fp32_flat_tensor.grad = None + # for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items(): + # fp32_flat_tensor.grad = None def _sanity_checks(self): assert torch.cuda.is_available(), 'CUDA is required' - assert self._dtype == torch.float16, \ - f'Parameters are expected to be of type torch.float16, but got {self._dtype}' + for param_group in self._optimizer.param_groups: + group_params = param_group['params'] + for param in group_params: + assert param.dtype == self._dtype, \ + f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" ########################################################### # Backward Reduction Hook @@ -389,6 +403,18 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): loss = self.loss_scale * loss loss.backward(retain_graph=retain_graph) + # finish gradient reduction + if not self._partition_grads: + self._reduce_grad_stage1() + else: + # TODO: support async comm in reduce + self._reduce_grad_stage2() + + # clear reduced grads + if self._overlap_communication: + torch.cuda.synchronize() + self._param_store.clear_grads_of_previous_reduced_params() + def zero_grad(self, set_to_none=True): """ Set parameter gradients to zero. If set_to_none = True, gradient @@ -465,7 +491,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # update fp16 partition updated by the current rank for group_id in range(len(self._fp16_param_groups)): fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id) - fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device) + fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] fp16_param.data.copy_(fp32_param) # broadcast the updated model weights @@ -524,22 +550,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ############################ def sync_grad(self): - if not self._partition_grads: - self._reduce_grad_stage1() - else: - # TODO: support async comm in reduce - self._reduce_grad_stage2() - # update param already reduced flag reduction_states = self._param_store.get_param_reduction_states() for tensor, state in reduction_states.items(): reduction_states[tensor] = False - # clear reduced grads - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - # accumulate gradient avg_gradients = self._grad_store._averaged_gradients for group_id in range(self.num_param_groups): diff --git a/tests/test_zero/low_level_zero/test_grad_acc.py b/tests/test_zero/low_level_zero/test_grad_acc.py new file mode 100644 index 000000000..c23b3a3e8 --- /dev/null +++ b/tests/test_zero/low_level_zero/test_grad_acc.py @@ -0,0 +1,167 @@ +import copy +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close + +import colossalai +from colossalai.testing.random import seed_all +from colossalai.utils import free_port +from colossalai.zero import LowLevelZeroOptimizer + + +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +def exam_zero_1_2_grad_acc(): + local_rank = torch.distributed.get_rank() + seed_all(2009) + + # create model + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) + + # create optimizer + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=32, + clip_grad_norm=1.0, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=32, + clip_grad_norm=1.0) + # create data + seed_all(2021 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data): + # zero-dp forward + zero1_output = zero1_model(cur_data) + zero2_output = zero2_model(cur_data) + assert torch.equal(zero1_output, zero2_output) + + # zero-dp backward + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) + + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) + + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1) + fwd_bwd_func(1, input_data2) + + # step + zero1_optimizer.step() + zero2_optimizer.step() + + # check updated param + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) + + +def exam_zero_1_grad_acc(): + local_rank = torch.distributed.get_rank() + grad_scale = 32 + seed_all(2008) + + # create models + zero_model = TestModel() + torch_model = copy.deepcopy(zero_model) + + zero_model = zero_model.cuda() + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + + # create optimizer + zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1) + + # we only test stage 1 here + # in `check_sharded_param_consistency.py`, we will test whether + # level 1 and 2 will produce exactly the same results + zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, + overlap_communication=False, + initial_scale=grad_scale, + reduce_bucket_size=262144, + clip_grad_norm=1.0) + + torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) + + # create data + seed_all(2022 + local_rank) + input_data1 = torch.randn(32, 128).cuda() + input_data2 = torch.randn(32, 128).cuda() + + def fwd_bwd_func(number, cur_data, check_flag): + # zero-dp forward + zero_output = zero_model(cur_data) + + # torch-ddp forward + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + + # zero-dp backward + zero_optimizer.backward(zero_output.sum().float()) + # torch-ddp backward + torch_output.sum().backward() + + if check_flag: + # check grad + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + unscale_grad = z1p.grad / grad_scale + # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) + assert torch.equal(p.grad, unscale_grad) + + zero_optimizer.sync_grad() + + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) + + zero_optimizer.step() + torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) + torch_optimizer.step() + + # check updated param + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, p.shape, torch.max(p.data), torch.max(z1p.data), torch.max(torch.abs(p.data - z1p.data))) + assert_close(p.data, z1p.data) + + +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + + exam_zero_1_grad_acc() + # exam_zero_1_2_grad_acc() + + +@pytest.mark.dist +def test_grad_accumulation(): + world_size = 2 + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_grad_accumulation() diff --git a/tests/test_zero/low_level_zero/test_grad_clip.py b/tests/test_zero/low_level_zero/test_grad_clip.py deleted file mode 100644 index a6959352c..000000000 --- a/tests/test_zero/low_level_zero/test_grad_clip.py +++ /dev/null @@ -1,161 +0,0 @@ -import copy -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.utils import free_port -from colossalai.zero import LowLevelZeroOptimizer - - -def check_equal(a, b, rtol=1e-4, atol=1e-3): - """ - This function checks if two tensors are equal within tolerance - """ - assert torch.allclose(a.float(), b.float(), rtol=rtol, atol=atol), f'a = {a}, b = {b}' - - -def check_completely_equal(a, b): - """ - This function checks if two tensors are completely equal - """ - assert torch.all(a == b), f'a = {a}, b = {b}' - - -class TestModel(nn.Module): - - def __init__(self): - super(TestModel, self).__init__() - self.linear1 = nn.Linear(128, 256) - self.linear2 = nn.Linear(256, 512) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - -def exam_zero_1_2_grad_clip(): - # create model - zero1_model = TestModel().cuda().half() - zero2_model = copy.deepcopy(zero1_model) - - # create optimizer - zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=0.001) - zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=0.001) - zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, - overlap_communication=True, - initial_scale=32, - clip_grad_norm=1.0, - verbose=True) - zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=32, - clip_grad_norm=1.0) - - # create - input_data = torch.rand(32, 128).cuda().half() - - # forward - zero1_output = zero1_model(input_data) - zero2_output = zero2_model(input_data) - check_completely_equal(zero1_output, zero2_output) - - # backward - zero1_optimizer.backward(zero1_output.mean().float()) - zero2_optimizer.backward(zero2_output.mean().float()) - - # check grad - # as this param is small, the backward reduction - # will not be fired - for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - check_completely_equal(z1p.grad, z2p.grad) - - # step - zero1_optimizer.sync_grad() - zero2_optimizer.sync_grad() - - # step - zero1_optimizer.step() - zero2_optimizer.step() - - # check updated param - for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): - check_completely_equal(z1p.data, z2p.data) - - -def exam_zero_1_grad_clip(): - # create models - zero_model = TestModel() - torch_model = copy.deepcopy(zero_model) - - zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda()) - - # create optimizer - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001) - - # we only test stage 1 here - # in `check_sharded_param_consistency.py`, we will test whether - # level 1 and 2 will produce exactly the same results - zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, - overlap_communication=True, - initial_scale=1, - clip_grad_norm=1.0) - - torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001) - - # create - input_data = torch.rand(32, 128).cuda() - - # zero-dp forward - zero_output = zero_model(input_data.half()) - - # torch-ddp forward - torch_output = torch_model(input_data) - check_equal(zero_output, torch_output) - - # zero-dp backward - zero_optimizer.backward(zero_output.mean().float()) - - # torch-ddp backward - torch_output.mean().backward() - - # check grad - for p, z1p in zip(torch_model.parameters(), zero_model.parameters()): - check_equal(p.grad, z1p.grad) - - # zero-dp step - zero_optimizer.sync_grad() - zero_optimizer.step() - - # torch ddp step - torch.nn.utils.clip_grad_norm_(torch_model.parameters(), 1.0) - torch_optimizer.step() - - # check updated param - for p, z1p in zip(torch_model.parameters(), zero_model.parameters()): - check_equal(p.data, z1p.data, atol=5e-4) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - - exam_zero_1_2_grad_clip() - exam_zero_1_grad_clip() - - -@pytest.mark.dist -def test_grad_clip(): - world_size = 2 - run_func = partial(run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_grad_clip() diff --git a/tests/test_zero/low_level_zero/test_zero1_2.py b/tests/test_zero/low_level_zero/test_zero1_2.py index 8a510daaf..b02d3a6a4 100644 --- a/tests/test_zero/low_level_zero/test_zero1_2.py +++ b/tests/test_zero/low_level_zero/test_zero1_2.py @@ -6,27 +6,41 @@ import torch import torch.multiprocessing as mp import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP +from torch.testing import assert_close import colossalai +from colossalai.testing.random import seed_all from colossalai.utils import free_port from colossalai.zero import LowLevelZeroOptimizer -def check_equal(a, b): - """ - This function checks if two tensors are equal within tolerance - """ - assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}' +class TestModel(nn.Module): + + def __init__(self): + super(TestModel, self).__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x -def check_completely_equal(a, b): - """ - This function checks if two tensors are completely equal - """ - assert torch.all(a == b), f'a = {a}, b = {b}' +def half_close(a, b, loose=False): + rtol = None + atol = None + if loose: + rtol = 5e-2 + atol = 5e-4 + + a = a.detach().half() + b = b.detach().half() + + assert_close(a, b, rtol=rtol, atol=atol) -def check_sharded_param_consistency(): +def exam_zero_1_2(): """ In this test, we want to test whether zero stage 1 and 2 deliver the same numerical results despite different communication @@ -37,67 +51,54 @@ def check_sharded_param_consistency(): pg: partition gradients and optimizer states """ - - # create layers - oss_linear1 = nn.Linear(128, 256) - oss_linear2 = nn.Linear(256, 512) + local_rank = torch.distributed.get_rank() + seed_all(2001) # create model - oss_model = nn.Sequential(oss_linear1, oss_linear2) - pg_model = copy.deepcopy(oss_model) - - oss_model = oss_model.cuda().half() - pg_model = pg_model.cuda().half() + zero1_model = TestModel().cuda() + zero2_model = copy.deepcopy(zero1_model) # create optimizer - oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001) - pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001) - oss_optimizer = LowLevelZeroOptimizer(oss_optimizer, - overlap_communication=True, - initial_scale=1, - clip_grad_norm=0.0) - pg_optimizer = LowLevelZeroOptimizer(pg_optimizer, - overlap_communication=True, - partition_grad=True, - initial_scale=1, - clip_grad_norm=0.0) + zero1_optimizer = torch.optim.Adam(zero1_model.parameters(), lr=1) + zero2_optimizer = torch.optim.Adam(zero2_model.parameters(), lr=1) + zero1_optimizer = LowLevelZeroOptimizer(zero1_optimizer, + overlap_communication=True, + initial_scale=128, + verbose=True) + zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, + overlap_communication=True, + partition_grad=True, + initial_scale=128) + # create data + seed_all(2001 + local_rank) + input_data = torch.randn(32, 128).cuda() - # create - input_data = torch.rand(32, 128).cuda().half() + zero1_output = zero1_model(input_data) + zero2_output = zero2_model(input_data) + assert torch.equal(zero1_output, zero2_output) - # forward - oss_output = oss_model(input_data) - pg_output = pg_model(input_data) - check_completely_equal(oss_output, pg_output) + # zero-dp backward + zero1_optimizer.backward(zero1_output.mean().float()) + zero2_optimizer.backward(zero2_output.mean().float()) - # backward - oss_optimizer.backward(oss_output.mean().float()) - pg_optimizer.backward(pg_output.mean().float()) + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) - # check grad - # as this param is small, the backward reduction - # will not be fired - oss_linear1_grad = oss_model[0].weight.grad - oss_linear2_grad = oss_model[1].weight.grad - pg_linear1_grad = pg_model[0].weight.grad - pg_linear2_grad = pg_model[1].weight.grad - check_completely_equal(oss_linear1_grad, pg_linear1_grad) - check_completely_equal(oss_linear2_grad, pg_linear2_grad) + zero1_optimizer.sync_grad() + zero2_optimizer.sync_grad() # step - oss_optimizer.sync_grad() - pg_optimizer.sync_grad() - - # step - oss_optimizer.step() - pg_optimizer.step() + zero1_optimizer.step() + zero2_optimizer.step() # check updated param - check_completely_equal(oss_model[0].weight, pg_model[0].weight) - check_completely_equal(oss_model[1].weight, pg_model[1].weight) + for z1p, z2p in zip(zero1_model.parameters(), zero2_model.parameters()): + assert torch.equal(z1p.data, z2p.data) -def check_sharded_optim_against_torch_ddp(): +def exam_zero_1_torch_ddp(): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -106,20 +107,22 @@ def check_sharded_optim_against_torch_ddp(): We feed these two sets of models with the same input and check if the differences in model output and updated parameters are within tolerance. """ + local_rank = torch.distributed.get_rank() + seed_all(1453) - # create layer - zero_linear1 = nn.Linear(128, 256) - zero_linear2 = nn.Linear(256, 512) - - # create model - zero_model = nn.Sequential(zero_linear1, zero_linear2) + # create models + zero_model = TestModel() torch_model = copy.deepcopy(zero_model) zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda()) + # torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) + torch_model = torch_model.cuda() + + # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # half_close(p.data, z1p.data) # create optimizer - zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001) + zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) # we only test stage 1 here # in `check_sharded_param_consistency.py`, we will test whether @@ -127,10 +130,11 @@ def check_sharded_optim_against_torch_ddp(): zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=True, initial_scale=1, - clip_grad_norm=0.0) + reduce_bucket_size=262144) - torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001) + torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1) + seed_all(1453 + local_rank) # create input_data = torch.rand(32, 128).cuda() @@ -139,7 +143,7 @@ def check_sharded_optim_against_torch_ddp(): # torch-ddp forward torch_output = torch_model(input_data) - check_equal(zero_output, torch_output) + half_close(zero_output, torch_output, loose=True) # zero-dp backward zero_optimizer.backward(zero_output.mean().float()) @@ -148,12 +152,8 @@ def check_sharded_optim_against_torch_ddp(): torch_output.mean().backward() # check grad - zero_linear1_grad = zero_model[0].weight.grad - zero_linear2_grad = zero_model[1].weight.grad - torch_linear1_grad = torch_model.module[0].weight.grad - torch_linear2_grad = torch_model.module[1].weight.grad - check_equal(zero_linear1_grad, torch_linear1_grad) - check_equal(zero_linear2_grad, torch_linear2_grad) + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + half_close(p.grad, z1p.grad, loose=True) # zero-dp step zero_optimizer.sync_grad() @@ -163,23 +163,24 @@ def check_sharded_optim_against_torch_ddp(): torch_optimizer.step() # check updated param - check_equal(zero_model[0].weight, torch_model.module[0].weight) - check_equal(zero_model[1].weight, torch_model.module[1].weight) + for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): + # print(n, torch.max(torch.abs(p.data - z1p.data))) + half_close(p.data, z1p.data, loose=True) def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - check_sharded_optim_against_torch_ddp() - check_sharded_param_consistency() + exam_zero_1_torch_ddp() + exam_zero_1_2() @pytest.mark.dist -def test_sharded_optim(): +def test_zero_1_2(): world_size = 2 run_func = partial(run_dist, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': - test_sharded_optim() + test_zero_1_2()