[zero]support zero2 with gradient accumulation (#4511)

* support gradient accumulation with zero2

* fix type
This commit is contained in:
LuGY
2023-08-25 13:44:07 +08:00
committed by GitHub
parent c0efc3ebcb
commit 839847b7d7
4 changed files with 61 additions and 28 deletions

View File

@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
self._grads_of_params[group_id][param_id].append(grad)
def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
"""For old gradient accumulation, not in use now.
Add a gradient slice on an existing slice of the parameter's gradient
"""Add a gradient slice on an existing slice of the parameter's gradient
Used when no_sync is not activated.
Args:
grad (Tensor): The split gradient to append to list

View File

@@ -277,7 +277,11 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(flat_grads_per_rank[rank], grad_list)
for grad in grad_list:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id,
param_id)) < self._world_size:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // self._world_size))
@@ -291,7 +295,10 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
sync_tensor(recieved_grad, grad_in_bucket_current_rank)
for grad in grad_in_bucket_current_rank:
param_id = self._bucket_store.get_param_id_of_grad(grad)
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < 1:
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id)
else:
self._grad_store.add_gradients_by_param_id(grad, 0, group_id, param_id)
self._bucket_store.reset()
@@ -315,7 +322,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
"ZeRO2(partition_grads) and no_sync are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)

View File

@@ -1,5 +1,41 @@
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stage 1 and 2, we would denote it as ZeRO.
## Examples of ZeRO and gradient accumulation
The code below only shows a typical gradient accumulation process, and it drops a lot of details, such as the processing of loss.
```python
# examples of ZeRO1 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
if (idx + 1) % ACCUMULATE_STEP != 0:
with booster.no_sync(model, optimizer):
# under this context, the gradient would not sync when backward,
# left each rank having different gradient.
# It saves the backward time
booster.backward(loss, optimizer)
continue
else:
# need to sync all the accumulated gradient
booster.backward(loss, optimizer):
optimizer.step()
...
```
```python
# example of ZeRO2 with gradient accumulation
...
outputs = model(input)
loss = SomeLoss(outputs)
# ZeRO2 split the gradients and can NOT accumulate gradient with syncing.
booster.backward(loss, optimizer)
if (idx + 1) % ACCUMULATE_STEP == 0:
optimizer.step()
...
```
## Design:
### Notion
@@ -25,11 +61,11 @@ The data structure looks like this:
```
After that, the gradients would be flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
# g-X0 means flatten([g-00, g-10])
{
0: [g-0],
1: [g-1],
2: [g-2]
0: [g-X0],
1: [g-X1],
2: [g-X2]
}
```
For zero1, we iterate the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.