diff --git a/colossalai/zero/low_level/_utils.py b/colossalai/zero/low_level/_utils.py index 218f7603b..a9e552ebd 100644 --- a/colossalai/zero/low_level/_utils.py +++ b/colossalai/zero/low_level/_utils.py @@ -253,7 +253,7 @@ def compute_norm(gradients, params, dp_group, mp_group, norm_type=2): return total_norm -def sync_param(flat_tensor, tensor_list): +def sync_tensor(flat_tensor, tensor_list): """ Synchronize the flattened tensor and unflattened tensor list. When a list of tensor are flattened with `torch._utils._unflatten_dense_tensors`, diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index ec322a78b..98f1b78d0 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -1,3 +1,8 @@ +from typing import Dict + +import torch +from torch import Tensor +from torch._utils import _flatten_dense_tensors from torch.distributed import ProcessGroup from .base_store import BaseStore @@ -7,35 +12,102 @@ class BucketStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - self._params = dict() - self._num_elements_in_bucket = dict() + + # init and reset + self.current_group_id = 0 + # mapping gardient slices and parameter + self.grad_to_param_mapping = dict() + + self._param_list = [] + self._padding_size = [] self.reset() - def num_elements_in_bucket(self, reduce_rank: int = None): - return self._num_elements_in_bucket[reduce_rank] + def num_elements_in_bucket(self) -> int: + """Return the total number of elements in bucket - def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None): - self._num_elements_in_bucket[reduce_rank] += num_elements + Returns: + int: the total number of elements in bucket + """ - def add_param(self, tensor, reduce_rank: int = None): - self._params[reduce_rank].append(tensor) + return self._num_elements_in_bucket + + def add_param_grad(self, group_id: int, param: Tensor, padding_size: int): + """Add a param to bucket and record the padding size of a param for gradient padding + + Args: + group_id (int): The index of a parameter group + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._param_list.append(param) + self._padding_size.append(padding_size) + self._num_elements_in_bucket += (param.numel() + padding_size) + self.current_group_id = group_id + + def build_grad_in_bucket(self): + """Orgnize parameters' gradient(padding and split), follows the paramters' splitting method + + Data structure of self._grad_in_bucket: + { + rank0: [grad0_rank0, grad1_rank0, ...] + rank1: [grad1_rank1, grad1_rank1, ...] + } + """ + + for param, padding_size in zip(self._param_list, self._padding_size): + with torch.no_grad(): + grad = param.grad.detach().flatten() + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + grad_list = grad.split(grad.numel() // self._world_size) + for rank in range(self._world_size): + grad_current_rank = grad_list[rank].detach() + self.grad_to_param_mapping[id(grad_current_rank)] = id(param) + self._grad_in_bucket[rank].append(grad_current_rank) + param.grad = None + + def get_grad(self) -> Dict: + """Return the dictionary of gradients slices, of which the keys are ranks + + Returns: + Dict: The dictionary of gradients slices + """ + + return self._grad_in_bucket + + def get_flatten_grad(self) -> Tensor: + """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: + [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] + + Returns: + Tensor: the flattened gradients slices in the bucket + """ + + flat_grad = [] + for grad_list in self._grad_in_bucket.values(): + flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad = _flatten_dense_tensors(flat_grad) + return flat_grad + + def get_param_id_of_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to + + Args: + grad (Tensor): the gradient slice + + Returns: + int: the id of a parameter which the gradient slice belongs to + """ + + return self.grad_to_param_mapping[id(grad)] def reset(self): - keys = [None] + list(range(self._world_size)) - self._params = {rank: [] for rank in keys} - self._num_elements_in_bucket = {rank: 0 for rank in keys} - - def reset_by_rank(self, reduce_rank=None): - self._params[reduce_rank] = [] - self._num_elements_in_bucket[reduce_rank] = 0 - - def get_grad(self, reduce_rank: int = None): - param_list = self.get_param(reduce_rank) - for param in param_list: - # the param must have grad for reduction - assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced' - return [param.grad for param in param_list] - - def get_param(self, reduce_rank: int = None): - return self._params[reduce_rank] + self.grad_to_param_mapping = dict() + self._num_elements_in_bucket = 0 + self._param_list = [] + self._padding_size = [] + self._grad_in_bucket = dict() + for rank in range(self._world_size): + self._grad_in_bucket[rank] = [] diff --git a/colossalai/zero/low_level/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py index 942d7186e..0b86ec8ca 100644 --- a/colossalai/zero/low_level/bookkeeping/gradient_store.py +++ b/colossalai/zero/low_level/bookkeeping/gradient_store.py @@ -1,88 +1,92 @@ from typing import List from torch import Tensor +from torch._utils import _flatten_dense_tensors from .base_store import BaseStore class GradientStore(BaseStore): - def __init__(self, *args): + def __init__(self, *args, partition_grad: bool = False): super().__init__(*args) - # bookkeeping data structures - self._averaged_gradients = dict() - - # for backward reduction hooks - self._grad_acc_objs = [] - - def append_accumulate_grad_object(self, obj): """ - Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not - be attached successfully. + self._grads_of_params mapping the paramater and its gradient slices + data structure: + { + group_id:{ + param_id: [grad_rank0, grad_rank1, ...] + } + } + """ + self._grads_of_params = dict() + # for zero2, it's `param_id: [grad_local_rank]` + self._working_index = 0 if partition_grad else self._local_rank - :param obj: An object of :class:`AccumulateGrad` class - :type obj: :class:`AccumulateGrad` + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + """Return list of gradient slices of a specific parameter + + Args: + group_id (int): The index of a parameter group + param_id (int): The id of a parameter + + Returns: + List: the list of gradient slices of a parameter. """ - self._grad_acc_objs.append(obj) + if group_id in self._grads_of_params: + if param_id in self._grads_of_params[group_id]: + return self._grads_of_params[group_id][param_id] + # the param has no grad, for instance, in layer drop + return [] - def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]: - """ - Return average gradients of a parameter group - - :param group_id: The index of parameter group - :type group_id: int - - :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter. - :rtype: List[torch.Tensor] - """ - if group_id not in self._averaged_gradients: - self._averaged_gradients[group_id] = [] - - return self._averaged_gradients[group_id] - - def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None: - """ - Append an average gradient to the list of averaged gradients of a parameter group - - :param group_id: The index of a parameter group - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor: torch.Tensor + def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int): + """Append a gradient slice to the parameter's gradient slice list + Args: + grad (Tensor): The gradient slice to append to list + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - if group_id in self._averaged_gradients: - self._averaged_gradients[group_id].append(tensor) + if group_id not in self._grads_of_params: + self._grads_of_params[group_id] = dict() + if param_id not in self._grads_of_params[group_id]: + self._grads_of_params[group_id][param_id] = [grad] else: - self._averaged_gradients[group_id] = [tensor] + self._grads_of_params[group_id][param_id].append(grad) - def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None: - """ - Add an average gradient to the list of averaged gradients of a parameter group + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): + """For old gradient accumulation, not in use now. + Add a gradient slice on an existing slice of the parameter's gradient - :param group_id: The index of a parameter group - :param tensor_idx: The index of a tensor in the list of averaged gradients - :param tensor: A :class:`torch.Tensor` object - :type group_id: int - :type tensor_idx: int - :type tensor: torch.Tensor - - """ - self._averaged_gradients[group_id][tensor_idx].add_(tensor) - - def reset_average_gradients_by_group(self, group_id: int) -> None: - """ - Reset the bookkeeping data structure for averaged gradients to an empty list - - :param group_id: The index of a parameter group - :type group_id: int + Args: + grad (Tensor): The split gradient to append to list + grad_idx (int): The index of the existing slice + group_id (int): The index of a parameter group + param_id (int): The id of a parameter """ - self._averaged_gradients[group_id] = [] + self._grads_of_params[group_id][param_id][grad_idx].add_(grad) - def reset_all_average_gradients(self) -> None: + def get_working_grads_by_group_id(self, group_id: int) -> List: + """Return list of working gradient slices in the group + + Args: + group_id (int): The index of a parameter group + + Returns: + List: the list working gradient slices in the group """ - Reset the bookkeeping data structure for averaged gradients to an empty list - """ - self._averaged_gradients = dict() + + grad_list = [] + for param_grads in self._grads_of_params[group_id].values(): + grad_list.append(param_grads[self._working_index]) + + return grad_list + + def reset_grads_by_group_id(self, group_id: int): + self._grads_of_params[group_id] = dict() + + def reset_all_gradients(self): + self._grads_of_params = dict() diff --git a/colossalai/zero/low_level/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py index 1f3ba7cbc..63f7c5506 100644 --- a/colossalai/zero/low_level/bookkeeping/parameter_store.py +++ b/colossalai/zero/low_level/bookkeeping/parameter_store.py @@ -1,5 +1,3 @@ -from typing import List - from torch import Tensor from torch.distributed import ProcessGroup @@ -10,88 +8,43 @@ class ParameterStore(BaseStore): def __init__(self, torch_pg: ProcessGroup): super().__init__(torch_pg) - # param partitioning data structures - self._param_to_rank = dict() - self._rank_group_id_to_param_list = dict() - self._rank_group_id_to_flat_param = dict() - # param reduction data structures - self._is_param_reduced = dict() - self._reduced_param = [] + # record the padding size of each param + self._padding_map = dict() - def set_param_to_rank(self, tensor: Tensor, rank: int) -> None: - """ - Set the mapping between parameter to rank, each parameter should be owned by a rank. + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - :param rank: The rank of which the process is responsible for updating the parameter - :type rank: int + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param + + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter """ - self._param_to_rank[tensor] = rank + self._padding_map[id(param)] = padding_size - def get_param_rank(self, tensor: Tensor) -> int: - """ - Gives the rank which the parameter belongs to + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - """ - return self._param_to_rank[tensor] + Args: + param (Tensor): The parameter - def belongs_to_current_rank(self, tensor) -> bool: - """ - Check whether a parameter is supposed to be updated by the process of the current rank - - :param tensor: A :class:`torch.Tensor` object - :type tensor: torch.Tensor - - :return: True if the parameter should be updated by the current rank. Otherwise false. - :rtype: bool + Returns: + int: the padding size of the parameter """ - tensor_rank = self._param_to_rank[tensor] - return tensor_rank == self._local_rank + return self._padding_map[id(param)] - def add_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None: - if rank not in self._rank_group_id_to_param_list: - self._rank_group_id_to_param_list[rank] = dict() + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter - if group_id not in self._rank_group_id_to_param_list[rank]: - self._rank_group_id_to_param_list[rank][group_id] = [] + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ - self._rank_group_id_to_param_list[rank][group_id].extend(tensor_list) - - def get_params_by_rank_group(self, rank, group_id) -> List[Tensor]: - return self._rank_group_id_to_param_list[rank][group_id] - - def add_flat_param_by_rank_group(self, rank, group_id, tensor) -> None: - if rank not in self._rank_group_id_to_flat_param: - self._rank_group_id_to_flat_param[rank] = dict() - - self._rank_group_id_to_flat_param[rank][group_id] = tensor - - def get_flat_param_by_rank_group(self, rank, group_id) -> Tensor: - return self._rank_group_id_to_flat_param[rank][group_id] - - def is_param_reduced(self, tensor): - return self._is_param_reduced[tensor] - - def set_param_reduction_state(self, tensor, state): - self._is_param_reduced[tensor] = state - - def get_param_reduction_states(self): - return self._is_param_reduced - - def reset_previous_reduced_params(self): - self._reduced_param = [] - - def add_previous_reduced_param(self, tensor): - self._reduced_param.append(tensor) - - def clear_grads_of_previous_reduced_params(self): - if len(self._reduced_param) > 0: - for param in self._reduced_param: - param.grad = None - self.reset_previous_reduced_params() + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index ee03c0f0a..8743cab33 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -1,4 +1,5 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +from contextlib import contextmanager from functools import partial from typing import Optional @@ -16,6 +17,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor import ColoParameter, ProcessGroup +from colossalai.utils import conditional_context from colossalai.utils.cuda import get_current_device from ._utils import ( @@ -23,12 +25,10 @@ from ._utils import ( compute_norm, flatten, has_inf_or_nan, - reduce_tensor_dp_group, release_param_grad, - split_by_dtype, - sync_param, + sync_tensor, ) -from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +from .bookkeeping import BucketStore, GradientStore, ParameterStore class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): @@ -50,7 +50,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): def check_local_overflow(self) -> bool: for group_id in range(self.num_working_param_groups): - for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id): + for avg_grad in self.grad_store.get_working_grads_by_group_id(group_id): if avg_grad is not None and has_inf_or_nan(avg_grad): return True return False @@ -77,14 +77,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): overlap_communication: bool = False, partition_grad: bool = False, # stage 2 flag cpu_offload: bool = False, # cpu offload + grad_accumulate_interval: int = 1, forced_dtype: Optional[torch.dtype] = None): - # TODO: add support for - # 1. fp16 master weights - # 2. contiguous gradients - # 3. cpu offload - # 4. support when some parameters requires_grad = False - # 5. support layer drop + assert not (partition_grad and grad_accumulate_interval > 1), \ + "gradient accumulation is not compatible with ZeRO-2" super(LowLevelZeroOptimizer, self).__init__(optim=optimizer) self._dtype = self.optim.param_groups[0]['params'][0].dtype self._logger = get_dist_logger() @@ -95,6 +92,11 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._cpu_offload = cpu_offload + # grad accumulation + self.require_grad_sync = True + self._accumulate_intervel = grad_accumulate_interval + self._accumulate_step = 0 + colo_pg = self._search_colo_process_group() if isinstance(colo_pg, ProcessGroup): self._local_rank = colo_pg.dp_local_rank() @@ -122,7 +124,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # working and master params for mixed precision training self._working_param_groups = dict() - self._master_flat_param_groups_of_current_rank = dict() + self._master_param_groups_of_current_rank = dict() # communication params self._overlap_communication = overlap_communication @@ -145,7 +147,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # ParameterStore will manage the tensor buffers used for zero # it will not manage the tensors used by mixed precision training self._param_store = ParameterStore(self._dp_torch_group) - self._grad_store = GradientStore(self._dp_torch_group) + self._grad_store = GradientStore(self._dp_torch_group, partition_grad=partition_grad) self._bucket_store = BucketStore(self._dp_torch_group) # iterate over the param group in the optimizer @@ -160,55 +162,17 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # add the working params to working_param_groups for bookkeeping self._working_param_groups[group_id] = group_params - # assign parameters to ranks - # the params in the list are sorted - params_per_rank = self._partition_param_list(group_params) + master_param_current_rank = self._create_master_param_current_rank(group_params) - # store the mapping between param to rank - # each param should belong to only one rank - for rank, params in enumerate(params_per_rank): - self._param_store.add_param_list_by_rank_group(rank, group_id, params) - for param in params: - self._param_store.set_param_to_rank(param, rank) - - # move to cpu to make room to create the flat tensor - # move_tensor(params, device='cpu') - for param in group_params: - param.data = param.data.cpu() - - # flatten the reordered tensors - for rank in range(self._world_size): - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - with torch.no_grad(): - flat_tensor = flatten(tensor_list) - flat_tensor = flat_tensor.data.cuda() - self._param_store.add_flat_param_by_rank_group(rank, group_id, flat_tensor) - - # sync parameters - for rank in range(self._world_size): - flat_tensor = self._param_store.get_flat_param_by_rank_group(rank, group_id) - tensor_list = self._param_store.get_params_by_rank_group(rank, group_id) - sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - - # create a copy of fp32 master weights of the parameters for which this rank is responsible - working_flat_current_rank = self._param_store.get_flat_param_by_rank_group(self._local_rank, group_id) - master_flat_current_rank = working_flat_current_rank.float() - device = 'cpu' if self._cpu_offload else get_current_device() - master_flat_current_rank = master_flat_current_rank.to(device) - master_flat_current_rank.requires_grad = True - self._master_flat_param_groups_of_current_rank[group_id] = master_flat_current_rank + self._master_param_groups_of_current_rank[group_id] = master_param_current_rank # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors # managed by this data parallel rank - param_group['params'] = [master_flat_current_rank] + param_group['params'] = master_param_current_rank - # set reduction state - for param in self._working_param_groups[group_id]: - self._param_store.set_param_reduction_state(param, False) - - # initialize communication stream for - # communication-computation overlapping + # intialize communication stream for + # communication-compuation overlapping if self._overlap_communication: self._comm_stream = torch.cuda.Stream() @@ -265,29 +229,36 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): raise RuntimeError("All parameters should be ColoParameter if you use ColoParameter.") return colo_pg - def _partition_param_list(self, param_list): - params_per_rank = [[] for _ in range(self._world_size)] - numel_per_rank = [0 for _ in range(self._world_size)] + def _create_master_param_current_rank(self, param_list): + # split each param evenly by world size + params_current_rank = [] + device = 'cpu' if self._cpu_offload else get_current_device() - # partition the parameters in a greedy fashion - sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for param in sorted_params: - # allocate this parameter to the rank with - # the smallest numel for load balancing purpose - rank_to_go = numel_per_rank.index(min(numel_per_rank)) - params_per_rank[rank_to_go].append(param) - numel_per_rank[rank_to_go] += param.numel() + for param in reversed(param_list): + padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size + self._param_store.record_param_padding_size(param, padding_size) - if self._verbose: - self._logger.info(f'Number of elements on ranks: {numel_per_rank}', ranks=[0]) - return params_per_rank + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + else: + padding_param = param.data.view(-1) + splited_params = padding_param.split(padding_param.numel() // self._world_size) + + splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device) + params_current_rank.append(splited_param_current_rank) + self._param_store.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank ########################### # Backward Reduction Hook # ########################### - def _grad_handler(self, param, grad, reduce_rank): - self._add_to_reduction_bucket(param, reduce_rank) + def _grad_handler(self, param, group_id, grad): + # if run with no_sync context, would not sync grad when backward + if self.require_grad_sync: + self._add_to_bucket(param, group_id) return grad def _attach_reduction_hook(self): @@ -297,149 +268,96 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): param_group = self._working_param_groups[group_id] for param in param_group: if param.requires_grad: - # determines the reduction destination rank - # this is only valid for stage 2 - # dst_rank = None means using all-reduce - # else using reduce - if self._partition_grads: - reduce_rank = self._param_store.get_param_rank(param) - else: - reduce_rank = None - - param.register_hook(partial(self._grad_handler, param, reduce_rank=reduce_rank)) - - def _reduce_tensor_bucket(self, bucket: TensorBucket, reduce_rank): - if self._overlap_communication: - torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() - - with torch.cuda.stream(stream): - flat = bucket.flatten() - reduce_global_rank = None - if reduce_rank is not None: - reduce_global_rank = self._dp_global_ranks[reduce_rank] - reduced_flat = reduce_tensor_dp_group(tensor=flat, - dtype=self._communication_dtype, - dst_local_rank=reduce_rank, - dst_global_rank=reduce_global_rank, - group=self._dp_torch_group) - - # update the reduced tensor - if reduce_rank is None or reduce_rank == self._local_rank: - bucket.unflatten_and_copy(reduced_flat) - - def _reduce_tensor_list_with_one_dtype(self, tensor_list, bucket_size, reduce_rank): - param_bucket = TensorBucket(size=bucket_size) - - for tensor in tensor_list: - param_bucket.add_to_bucket(tensor, allow_oversize=True) - - if param_bucket.is_full_or_oversized(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - param_bucket.empty() - - if not param_bucket.is_empty(): - self._reduce_tensor_bucket(bucket=param_bucket, reduce_rank=reduce_rank) - - def _reduce_grads(self, reduce_rank, grads, bucket_size): - grad_buckets_by_dtype = split_by_dtype(grads) - - for tensor_list in grad_buckets_by_dtype: - self._reduce_tensor_list_with_one_dtype(tensor_list=tensor_list, - bucket_size=bucket_size, - reduce_rank=reduce_rank) + param.register_hook(partial(self._grad_handler, param, group_id)) ####################### # Reduction Functions # ####################### - def _run_reduction(self, reduce_rank=None): - # reduce grads - self._reduce_grads(reduce_rank=reduce_rank, - grads=self._bucket_store.get_grad(reduce_rank=reduce_rank), - bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank)) + def _run_reduction(self): + if self._bucket_store.num_elements_in_bucket() > 0: + self._bucket_store.build_grad_in_bucket() + flat_grads = self._bucket_store.get_flatten_grad() + flat_grads /= self._world_size + if self._overlap_communication: + stream = self._comm_stream + else: + stream = torch.cuda.current_stream() - # use communication stream if overlapping - # communication with computation - if self._overlap_communication: - stream = self._comm_stream - else: - stream = torch.cuda.current_stream() + with torch.cuda.stream(stream): + group_id = self._bucket_store.current_group_id - with torch.cuda.stream(stream): - params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank) + grad_dtype = flat_grads.dtype + if self._communication_dtype is not None: + flat_grads = flat_grads.to(self._communication_dtype) - for param in params_in_bucket: - # the is_param_reduced flag should be False showing that - # this param is not reduced before calling self._reduce_grads_by_rank - is_param_reduced = self._param_store.is_param_reduced(param) + if not self._partition_grads: + dist.all_reduce(flat_grads, group=self._dp_torch_group) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has been reduced, ' + \ - 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) + flat_grads_per_rank = flat_grads.split(flat_grads.numel() // self._world_size) + grad_in_bucket = self._bucket_store.get_grad() - # update the flag - self._param_store.set_param_reduction_state(param, True) + for rank, grad_list in grad_in_bucket.items(): + sync_tensor(flat_grads_per_rank[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) - # if partition grads = True - # we do not keep the gradient after reduction - if self._partition_grads and not self._param_store.belongs_to_current_rank(param): - if self._overlap_communication: - # we need to keep this gradient for now as reduction may - # be completed yet since it is using a different cuda stream - self._param_store.add_previous_reduced_param(param) - else: - param.grad = None + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=self._dp_torch_group) - self._bucket_store.reset_by_rank(reduce_rank) + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) - def _add_to_reduction_bucket(self, param, reduce_rank=None): + grad_in_bucket_current_rank = self._bucket_store.get_grad()[self._local_rank] + sync_tensor(recieved_grad, grad_in_bucket_current_rank) + for grad in grad_in_bucket_current_rank: + param_id = self._bucket_store.get_param_id_of_grad(grad) + self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + + self._bucket_store.reset() + + def _add_to_bucket(self, param, group_id): param_size = param.numel() # check if the bucket is full # if full, will reduce the grads already in the bucket + # or got a grad of param from another group # after reduction, the bucket will be empty - if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size: - self._run_reduction(reduce_rank) + if self._bucket_store.num_elements_in_bucket() + param_size > self._reduce_bucket_size or \ + group_id != self._bucket_store.current_group_id: + self._run_reduction() - # the param must not be reduced to ensure correctness - is_param_reduced = self._param_store.is_param_reduced(param) - if is_param_reduced: - msg = f'Parameter of size ({param.size()}) has already been reduced, ' \ - + 'duplicate reduction will lead to arithmetic incorrectness' - raise RuntimeError(msg) - - self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank) - self._bucket_store.add_param(param, reduce_rank) + padding_size = self._param_store.get_param_padding_size(param) + self._bucket_store.add_param_grad(group_id, param, padding_size) ################################ # torch.optim.Optimizer methods ################################ - def backward(self, loss, retain_graph=False, sync_grad=True): + def backward(self, loss, retain_graph=False): if self.mixed_precision_mixin is not None: loss = self.mixed_precision_mixin.pre_backward(loss) - loss.backward(retain_graph=retain_graph) - # finish gradient reduction - if not self._partition_grads: - self._reduce_grad_stage1() - else: - # TODO: support async comm in reduce - self._reduce_grad_stage2() + self._accumulate_step += 1 + no_sync = self._accumulate_step < self._accumulate_intervel + with conditional_context(self.no_sync(), enable=no_sync): + loss.backward(retain_graph=retain_graph) + + if no_sync: + return + + self._reduce_grad(self._partition_grads) # clear reduced grads if self._overlap_communication: torch.cuda.synchronize() - self._param_store.clear_grads_of_previous_reduced_params() - # gradient synchronization - if sync_grad: - self._sync_grad() + self.zero_grad() def zero_grad(self, set_to_none=True): """ @@ -467,68 +385,86 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' + if not self._accumulate_step == self._accumulate_intervel: + return + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): - self._grad_store.reset_all_average_gradients() + self._grad_store.reset_all_gradients() if self._verbose: self._logger.info(f'Found overflow. Skip step') self.zero_grad() + self._accumulate_step -= 1 return - # copy the grad of working param to master param - single_grad_partition_groups = [] + # record all grads for unscale and clip + grad_partition_groups = [] norm_groups = [] + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + + grad_index = 0 if self._partition_grads else self._local_rank + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + for splited_param in master_params: + working_param = self._param_store.master_to_working_param[id(splited_param)] + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. the droped layer + # else the splited grad should be attached to the splited param + grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) + # compute norm - norm_group = compute_norm(gradients=self._grad_store.get_averaged_gradients_by_group(group_id), - params=self._param_store.get_params_by_rank_group(group_id=group_id, - rank=self._local_rank), + working_grads = self._grad_store.get_working_grads_by_group_id(group_id) + norm_group = compute_norm(gradients=working_grads, + params=real_working_params[group_id], dp_group=self._dp_torch_group, mp_group=self._mp_torch_group) norm_groups.append(norm_group) - # create flat gradient for the flat fp32 master params - working_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id) - flat_working_avg_grads = flatten(working_avg_grads) + self._grad_store.reset_grads_by_group_id(group_id) - dtype = self._master_flat_param_groups_of_current_rank[group_id].dtype - flat_master_avg_grads = flat_working_avg_grads.to(dtype) - - param_shape = self._master_flat_param_groups_of_current_rank[group_id].shape - assert param_shape == flat_master_avg_grads.shape, \ - f'fp32 param and grad have different shape {param_shape} vs {flat_master_avg_grads.shape}' - - single_grad_partition_groups.append(flat_master_avg_grads) - device = self._master_flat_param_groups_of_current_rank[group_id].device - self._master_flat_param_groups_of_current_rank[group_id].grad = flat_master_avg_grads.to(device) - self._grad_store.reset_average_gradients_by_group(group_id) + # update the params in the optimizer + self.optim.param_groups[group_id]['params'] = real_master_params[group_id] # unscale and clip grads global_norm = calculate_global_norm_from_list(norm_list=norm_groups) - self._unscale_and_clip_grads(single_grad_partition_groups, global_norm) + self._unscale_and_clip_grads(grad_partition_groups, global_norm) # update the parameters self.optim.step() - # release the master grad - release_param_grad(self._master_flat_param_groups_of_current_rank.values()) + + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) # update working partition updated by the current rank - for group_id in range(len(self._working_param_groups)): - working_param = self._param_store.get_flat_param_by_rank_group(rank=self._local_rank, group_id=group_id) - master_param = self._master_flat_param_groups_of_current_rank[group_id] - working_param.data.copy_(master_param) - - # broadcast the updated model weights - handles = [] for group_id in range(self.num_param_groups): - for index in range(self._world_size): - rank = self._dp_global_ranks[index] - working_param = self._param_store.get_flat_param_by_rank_group(rank=index, group_id=group_id) - handle = dist.broadcast(working_param, src=rank, group=self._dp_torch_group, async_op=True) - handles.append(handle) + master_working_param = self.optim.param_groups[group_id]['params'] - for handle in handles: - handle.wait() + for idx, splited_param in enumerate(master_working_param): + full_master_param = [torch.zeros_like(splited_param).cuda() for _ in range(self._world_size)] + dist.all_gather(full_master_param, splited_param.cuda(), group=self._dp_torch_group) + working_param = real_working_params[group_id][idx] + full_master_param = flatten(full_master_param)[:working_param.numel()].reshape_as(working_param) + working_param.data.copy_(full_master_param) + + self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id] + + # reset accumulate step + self._accumulate_step = 0 ############################# # Mixed Precision Utilities # @@ -553,49 +489,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # Gradient Synchronization # ############################ - def _sync_grad(self): - # update param already reduced flag - reduction_states = self._param_store.get_param_reduction_states() - for tensor, _ in reduction_states.items(): - reduction_states[tensor] = False - - # accumulate gradient - for group_id in range(self.num_param_groups): - param_group = self._param_store.get_params_by_rank_group(self._local_rank, group_id) - - avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(group_id) - - param_idx = 0 - for param in param_group: - if param.grad is not None: - if len(avg_gradients_group) == param_idx: - self._grad_store.append_average_gradient_by_group(group_id, param.grad) - else: - self._grad_store.add_average_gradient_by_group(group_id, param_idx, param.grad) - param_idx += 1 - - # the gradients needed are stored in the avg_gradients buffer - # thus, can clear this - self.zero_grad() - - def _reduce_grad_stage1(self): - # if not overlapping communication (no reduction hook is attached) + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 # we need to manually reduce these gradients - if not self._overlap_communication: + if not partition_grad and not self._overlap_communication: for group_id in range(len(self._working_param_groups)): param_group = self._working_param_groups[group_id] for param in param_group: if param.grad is not None: - self._add_to_reduction_bucket(param) + self._add_to_bucket(param, group_id) - # we need to reduce the gradients - # left in the communication bucket + # run reduction self._run_reduction() - def _reduce_grad_stage2(self): - # when partition_grads is True, reduction hooks - # are attached in the __init__ function, so we - # only need to reduce the gradients - # left in the communication bucket - for reduce_rank in range(self._world_size): - self._run_reduction(reduce_rank) + # this context comes from pytorch DDP + @contextmanager + def no_sync(self): + old_require_grad_sync = self.require_grad_sync + self.require_grad_sync = False + try: + yield + finally: + self.require_grad_sync = old_require_grad_sync diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index eedd8c59a..79f98a4c9 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -11,14 +11,9 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from tests.kit.model_zoo import model_zoo # These models are not compatible with AMP -_AMP_ERR_MODELS = ['timm_convit', 'dlrm', 'deepfm_interactionarch', 'deepfm_simpledeepfmnn'] +_AMP_ERR_MODELS = ['timm_convit', 'deepfm_interactionarch'] # These models have no parameters -_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch', 'deepfm_overarch', 'deepfm_sparsearch', 'dlrm_sparsearch'] -# These models will get stuck -_STUCK_MODELS = [ - 'diffusers_vq_model', 'transformers_albert', 'transformers_albert_for_pretraining', 'transformers_bert', - 'transformers_bert_for_pretraining', 'transformers_gpt_double_heads' -] +_LOW_LEVEL_ZERO_ERR_MODELS = ['dlrm_interactionarch'] def run_fn(stage, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]: @@ -58,7 +53,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): """ passed_models = [] failed_info = {} # (model_name, error) pair - ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS + ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS skipped_models = [] for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index c264a8077..ac1f677f9 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -39,37 +39,37 @@ def exam_zero_1_2_grad_acc(): overlap_communication=True, initial_scale=32, clip_grad_norm=1.0, + grad_accumulate_interval=2, verbose=True) zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer, overlap_communication=True, partition_grad=True, initial_scale=32, - clip_grad_norm=1.0) + clip_grad_norm=1.0, + grad_accumulate_interval=2) # create data seed_all(2021 + local_rank) input_data1 = torch.randn(32, 128).cuda() input_data2 = torch.randn(32, 128).cuda() - def fwd_bwd_func(number, cur_data): + def fwd_bwd_func(number, cur_data, check_flag): # zero-dp forward zero1_output = zero1_model(cur_data) zero2_output = zero2_model(cur_data) assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.sum().float(), sync_grad=False) - zero2_optimizer.backward(zero2_output.sum().float(), sync_grad=False) + zero1_optimizer.backward(zero1_output.sum().float()) + zero2_optimizer.backward(zero2_output.sum().float()) - for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): - if z2p.grad is not None: - # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) - assert torch.equal(z1p.grad, z2p.grad) + if check_flag: + for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): + if z2p.grad is not None: + # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) + assert torch.equal(z1p.grad, z2p.grad) - zero1_optimizer._sync_grad() - zero2_optimizer._sync_grad() - - fwd_bwd_func(0, input_data1) - fwd_bwd_func(1, input_data2) + fwd_bwd_func(0, input_data1, True) + fwd_bwd_func(1, input_data2, False) # step zero1_optimizer.step() @@ -101,7 +101,8 @@ def exam_zero_1_grad_acc(): zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=False, reduce_bucket_size=262144, - clip_grad_norm=1.0) + clip_grad_norm=1.0, + grad_accumulate_interval=2) torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1) @@ -115,13 +116,19 @@ def exam_zero_1_grad_acc(): zero_output = zero_model(cur_data) # torch-ddp forward - torch_output = torch_model(cur_data) - assert torch.equal(zero_output, torch_output) # zero-dp backward - zero_optimizer.backward(zero_output.sum().float(), sync_grad=False) + zero_optimizer.backward(zero_output.sum().float()) # torch-ddp backward - torch_output.sum().backward() + if number < 1: + with torch_model.no_sync(): + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + torch_output.sum().backward() + else: + torch_output = torch_model(cur_data) + assert torch.equal(zero_output, torch_output) + torch_output.sum().backward() if check_flag: # check grad @@ -129,8 +136,6 @@ def exam_zero_1_grad_acc(): # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) assert torch.equal(p.grad, z1p.grad) - zero_optimizer._sync_grad() - fwd_bwd_func(0, input_data1, True) fwd_bwd_func(1, input_data2, False) @@ -148,7 +153,8 @@ def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') exam_zero_1_grad_acc() - exam_zero_1_2_grad_acc() + # gradient accumulation is not compatible with ZeRO-2 + # exam_zero_1_2_grad_acc() @pytest.mark.dist diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 8e2206fe6..5a0609bff 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -2,6 +2,7 @@ import copy import pytest import torch +import torch.distributed as dist import torch.nn as nn from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close @@ -16,8 +17,9 @@ class MlpModel(nn.Module): def __init__(self): super(MlpModel, self).__init__() - self.linear1 = nn.Linear(128, 256) - self.linear2 = nn.Linear(256, 512) + self.linear1 = nn.Linear(123, 253) + self.linear_drop = nn.Linear(253, 253) + self.linear2 = nn.Linear(253, 512) def forward(self, x): x = self.linear1(x) @@ -41,6 +43,16 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32): assert_close(a, b, rtol=rtol, atol=atol) +def split_ddp_grad(grad, world_size): + with torch.no_grad(): + grad = grad.clone().detach().flatten() + padding_size = (world_size - grad.numel() % world_size) % world_size + if padding_size > 0: + grad = torch.nn.functional.pad(grad, [0, padding_size]) + splited_grad = grad.split(grad.numel() // world_size) + return splited_grad + + def exam_zero_1_2(): """ In this test, we want to test whether zero stage 1 and 2 @@ -72,23 +84,21 @@ def exam_zero_1_2(): initial_scale=128) # create data seed_all(2001 + local_rank) - input_data = torch.randn(32, 128).cuda() + input_data = torch.randn(32, 123).cuda() zero1_output = zero1_model(input_data) zero2_output = zero2_model(input_data) assert torch.equal(zero1_output, zero2_output) # zero-dp backward - zero1_optimizer.backward(zero1_output.mean().float(), sync_grad=False) - zero2_optimizer.backward(zero2_output.mean().float(), sync_grad=False) + zero1_optimizer.backward(zero1_output.mean().float()) + zero2_optimizer.backward(zero2_output.mean().float()) - for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()): - if z2p.grad is not None: - # print(local_rank, n, z1p.shape, torch.max(z2p.grad), torch.max(torch.abs(z1p.grad - z2p.grad))) - assert torch.equal(z1p.grad, z2p.grad) - - zero1_optimizer._sync_grad() - zero2_optimizer._sync_grad() + # check grad + z1g_list = zero1_optimizer._grad_store.get_working_grads_by_group_id(0) + z2g_list = zero2_optimizer._grad_store.get_working_grads_by_group_id(0) + for z1g, z2g in zip(z1g_list, z2g_list): + assert torch.equal(z1g, z2g) # step zero1_optimizer.step() @@ -100,7 +110,7 @@ def exam_zero_1_2(): @parameterize('dtype', [torch.float16, torch.bfloat16]) -def exam_zero_1_torch_ddp(dtype: torch.dtype): +def exam_zero_1_torch_ddp(world_size, dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -116,7 +126,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype): torch_model = MlpModel().cuda() zero_model = copy.deepcopy(torch_model).to(dtype) - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda() + torch_model = DDP(torch_model.cuda(), static_graph=True).cuda() # create optimizer zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) @@ -133,7 +143,7 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype): seed_all(1453 + local_rank) # create - input_data = torch.rand(32, 128).cuda() + input_data = torch.rand(32, 123).cuda() # zero-dp forward zero_output = zero_model(input_data.to(dtype)) @@ -143,17 +153,20 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype): loose_close(zero_output, torch_output, dtype=dtype) # zero-dp backward - zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) + zero_optimizer.backward(zero_output.mean().float()) # torch-ddp backward torch_output.mean().backward() # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - loose_close(p.grad, z1p.grad, dtype=dtype) + if p.grad is not None: + zero_grad_list = zero_optimizer._grad_store.get_partitioned_gradients_by_param_id(0, id(z1p)) + torch_grad_list = split_ddp_grad(p.grad, world_size) + for zero_grad, torch_grad in zip(zero_grad_list, torch_grad_list): + loose_close(zero_grad, torch_grad, dtype=dtype) # zero-dp step - zero_optimizer._sync_grad() zero_optimizer.step() # torch ddp step @@ -161,14 +174,13 @@ def exam_zero_1_torch_ddp(dtype: torch.dtype): # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # print(n, torch.max(torch.abs(p.data - z1p.data))) loose_close(p.data, z1p.data, dtype=dtype) def run_dist(rank, world_size, port): colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') - exam_zero_1_torch_ddp() + exam_zero_1_torch_ddp(world_size=world_size) exam_zero_1_2()