diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
index 8a9128a18..b166752cc 100644
--- a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
+++ b/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
@@ -6,7 +6,6 @@ from .base_store import BaseStore
 
 
 class GradientStore(BaseStore):
-
     def __init__(self, *args):
         super().__init__(*args)
         # bookkeeping data structures
@@ -15,7 +14,7 @@ class GradientStore(BaseStore):
         # for backward reduction hooks
         self._grad_acc_objs = []
 
-    def add_accumulate_grad_object(self, obj):
+    def append_accumulate_grad_object(self, obj):
         """
         Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
         be attached successfully.
@@ -36,10 +35,12 @@ class GradientStore(BaseStore):
         :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
         :rtype: List[torch.Tensor]
         """
+        if group_id not in self._averaged_gradients:
+            self._averaged_gradients[group_id] = []
 
         return self._averaged_gradients[group_id]
 
-    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
+    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
         """
         Append an average gradient to the list of averaged gradients of a parameter group
 
@@ -55,6 +56,22 @@ class GradientStore(BaseStore):
         else:
             self._averaged_gradients[group_id] = [tensor]
 
+    def add_average_gradient_by_group(
+        self, group_id: int, tensor_idx: int, tensor: Tensor
+    ) -> None:
+        """
+        Add an average gradient to the list of averaged gradients of a parameter group
+
+        :param group_id: The index of a parameter group
+        :param tensor_idx: The index of a tensor in the list of averaged gradients
+        :param tensor: A :class:`torch.Tensor` object
+        :type group_id: int
+        :type tensor_idx: int
+        :type tensor: torch.Tensor
+
+        """
+        self._averaged_gradients[group_id][tensor_idx].add_(tensor)
+
     def reset_average_gradients_by_group(self, group_id: int) -> None:
         """
         Reset the bookkeeping data structure for averaged gradients to an empty list
diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/sharded_optim/low_level_optim.py
index 89f5f9fad..f5e03ce28 100644
--- a/colossalai/zero/sharded_optim/low_level_optim.py
+++ b/colossalai/zero/sharded_optim/low_level_optim.py
@@ -550,20 +550,24 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 reduction_states[tensor] = False
 
         # accumulate gradient
-        avg_gradients = self._grad_store._averaged_gradients
         for group_id in range(self.num_param_groups):
             param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)
 
-            if group_id not in avg_gradients:
-                avg_gradients[group_id] = []
+            avg_gradients_group = self._grad_store.get_averaged_gradients_by_group(
+                group_id
+            )
 
             param_idx = 0
             for param in param_group:
                 if param.grad is not None:
-                    if len(avg_gradients[group_id]) == param_idx:
-                        avg_gradients[group_id].append(param.grad)
+                    if len(avg_gradients_group) == param_idx:
+                        self._grad_store.append_average_gradient_by_group(
+                            group_id, param.grad
+                        )
                     else:
-                        avg_gradients[group_id][param_idx].add_(param.grad)
+                        self._grad_store.add_average_gradient_by_group(
+                            group_id, param_idx, param.grad
+                        )
                     param_idx += 1
 
         # the gradients needed are stored in the avg_gradients buffer
@@ -590,4 +594,4 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         # only need to reduce the gradients
         # left in the communication bucket
         for reduce_rank in range(self._world_size):
-            self._run_reduction(reduce_rank)
+            self._run_reduction(reduce_rank)
\ No newline at end of file
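
The sketch below is not part of the patch; it is a minimal stand-in (assumed names, not ColossalAI code) that mirrors the bookkeeping contract the patch establishes: `get_averaged_gradients_by_group` lazily creates the per-group list, `append_average_gradient_by_group` appends a first-step gradient, and the new `add_average_gradient_by_group` accumulates in place at a given index, as the rewritten loop in `low_level_optim.py` now expects.

```python
# Illustrative stand-in mirroring the patched GradientStore bookkeeping API.
# This is NOT ColossalAI code; constructor details of the real class are omitted.
from typing import Dict, List

import torch
from torch import Tensor


class _GradientStoreSketch:
    def __init__(self) -> None:
        self._averaged_gradients: Dict[int, List[Tensor]] = {}

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        # lazily create the per-group list, as the patched getter now does
        if group_id not in self._averaged_gradients:
            self._averaged_gradients[group_id] = []
        return self._averaged_gradients[group_id]

    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        # first accumulation step for a parameter: store the gradient tensor
        self.get_averaged_gradients_by_group(group_id).append(tensor)

    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
        # later steps: in-place accumulation, matching the method added by the patch
        self._averaged_gradients[group_id][tensor_idx].add_(tensor)


store = _GradientStoreSketch()
grads = store.get_averaged_gradients_by_group(0)            # [] on first access
store.append_average_gradient_by_group(0, torch.ones(4))    # first accumulation step
store.add_average_gradient_by_group(0, 0, torch.ones(4))    # subsequent steps add in place
assert torch.equal(grads[0], torch.full((4,), 2.0))
```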