Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-09 04:50:17 +00:00
[zero] refactor low level zero for shard evenly (#4030)
* refactor low level zero
* fix zero2 and support cpu offload
* avg gradient and modify unit test
* refactor grad store, support layer drop
* refactor bucket store, support grad accumulation
* fix and update unit test of zero and ddp
* compatible with tp, ga and unit test
* fix memory leak and polish
* add zero layer drop unittest
* polish code
* fix import err in unit test
* support different comm dtype, modify docstring style
* polish code
* test padding and fix
* fix unit test of low level zero
* fix pad recording in bucket store
* support some models
* polish
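The core "shard evenly" idea behind this refactor is that the tensors in a bucket are flattened and padded so the flat buffer divides exactly by the world size, giving every rank an equal slice. The sketch below is only an illustration of that idea, not code from this PR; the helper name `shard_evenly`, the world size, and the tensor shapes are assumptions.

import torch
from torch._utils import _flatten_dense_tensors

def shard_evenly(tensors, world_size):
    """Flatten a list of tensors, zero-pad to a multiple of world_size,
    and return one equal-sized slice per rank (illustrative sketch)."""
    flat = _flatten_dense_tensors(tensors)             # 1-D view of all tensors
    pad = (-flat.numel()) % world_size                  # elements needed to divide evenly
    if pad > 0:
        flat = torch.cat([flat, flat.new_zeros(pad)])   # caller keeps `pad` to strip it later
    return flat.chunk(world_size), pad

# toy example: two parameters, four ranks
params = [torch.randn(5), torch.randn(6)]                # 11 elements -> padded to 12
shards, pad = shard_evenly(params, world_size=4)
assert all(s.numel() == 3 for s in shards) and pad == 1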
@@ -1,88 +1,92 @@
 from typing import List
 
 from torch import Tensor
+from torch._utils import _flatten_dense_tensors
 
 from .base_store import BaseStore
 
 
 class GradientStore(BaseStore):
 
-    def __init__(self, *args):
+    def __init__(self, *args, partition_grad: bool = False):
         super().__init__(*args)
-        # bookkeeping data structures
-        self._averaged_gradients = dict()
-
-        # for backward reduction hooks
-        self._grad_acc_objs = []
-
-    def append_accumulate_grad_object(self, obj):
-        """
-        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
-        be attached successfully.
-
-        :param obj: An object of :class:`AccumulateGrad` class
-        :type obj: :class:`AccumulateGrad`
-        """
-
-        self._grad_acc_objs.append(obj)
-
-    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
-        """
-        Return average gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :type group_id: int
-
-        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
-        :rtype: List[torch.Tensor]
-        """
-        if group_id not in self._averaged_gradients:
-            self._averaged_gradients[group_id] = []
-
-        return self._averaged_gradients[group_id]
-
-    def append_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
-        """
-        Append an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor: torch.Tensor
-        """
-
-        if group_id in self._averaged_gradients:
-            self._averaged_gradients[group_id].append(tensor)
-        else:
-            self._averaged_gradients[group_id] = [tensor]
-
-    def add_average_gradient_by_group(self, group_id: int, tensor_idx: int, tensor: Tensor) -> None:
-        """
-        Add an average gradient to the list of averaged gradients of a parameter group
-
-        :param group_id: The index of a parameter group
-        :param tensor_idx: The index of a tensor in the list of averaged gradients
-        :param tensor: A :class:`torch.Tensor` object
-        :type group_id: int
-        :type tensor_idx: int
-        :type tensor: torch.Tensor
-        """
-        self._averaged_gradients[group_id][tensor_idx].add_(tensor)
-
-    def reset_average_gradients_by_group(self, group_id: int) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-
-        :param group_id: The index of a parameter group
-        :type group_id: int
-        """
-        self._averaged_gradients[group_id] = []
-
-    def reset_all_average_gradients(self) -> None:
-        """
-        Reset the bookkeeping data structure for averaged gradients to an empty list
-        """
-        self._averaged_gradients = dict()
+        """
+        self._grads_of_params: mapping from a parameter to its gradient slices
+        data structure:
+        {
+          group_id: {
+            param_id: [grad_rank0, grad_rank1, ...]
+          }
+        }
+        """
+        self._grads_of_params = dict()
+        # for zero2, it's `param_id: [grad_local_rank]`
+        self._working_index = 0 if partition_grad else self._local_rank
+
+    def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
+        """Return the list of gradient slices of a specific parameter
+
+        Args:
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+
+        Returns:
+            List: the list of gradient slices of a parameter.
+        """
+
+        if group_id in self._grads_of_params:
+            if param_id in self._grads_of_params[group_id]:
+                return self._grads_of_params[group_id][param_id]
+        # the param has no grad, for instance, in layer drop
+        return []
+
+    def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int):
+        """Append a gradient slice to the parameter's gradient slice list
+
+        Args:
+            grad (Tensor): The gradient slice to append to the list
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        if group_id not in self._grads_of_params:
+            self._grads_of_params[group_id] = dict()
+        if param_id not in self._grads_of_params[group_id]:
+            self._grads_of_params[group_id][param_id] = [grad]
+        else:
+            self._grads_of_params[group_id][param_id].append(grad)
+
+    def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
+        """For old gradient accumulation, not in use now.
+        Add a gradient slice onto an existing slice of the parameter's gradient
+
+        Args:
+            grad (Tensor): The split gradient to add
+            grad_idx (int): The index of the existing slice
+            group_id (int): The index of a parameter group
+            param_id (int): The id of a parameter
+        """
+
+        self._grads_of_params[group_id][param_id][grad_idx].add_(grad)
+
+    def get_working_grads_by_group_id(self, group_id: int) -> List:
+        """Return the list of working gradient slices in the group
+
+        Args:
+            group_id (int): The index of a parameter group
+
+        Returns:
+            List: the list of working gradient slices in the group
+        """
+
+        grad_list = []
+        for param_grads in self._grads_of_params[group_id].values():
+            grad_list.append(param_grads[self._working_index])
+
+        return grad_list
+
+    def reset_grads_by_group_id(self, group_id: int):
+        self._grads_of_params[group_id] = dict()
+
+    def reset_all_gradients(self):
+        self._grads_of_params = dict()