[zero] support zero2 with gradient accumulation (#4511)

* support gradient accumulation with zero2

* fix typo
LuGY
2023-08-25 13:44:07 +08:00
committed by GitHub
parent c0efc3ebcb
commit 839847b7d7
4 changed files with 61 additions and 28 deletions


@@ -57,8 +57,8 @@ class GradientStore(BaseStore):
         self._grads_of_params[group_id][param_id].append(grad)

     def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int):
-        """For old gradient accumulation, not in use now.
-        Add a gradient slice on an existing slice of the parameter's gradient
+        """Add a gradient slice on an existing slice of the parameter's gradient
+        Used when no_sync is not activated.

         Args:
             grad (Tensor): The split gradient to append to list
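For context, the training-loop pattern this change enables looks roughly like the sketch below: intermediate micro-batches accumulate gradients locally without synchronization, and only the last micro-batch of the window triggers the ZeRO-2 reduce-scatter and the optimizer step. This is a minimal illustration assuming a ColossalAI-style low-level ZeRO optimizer; the no_sync() context manager and the exact method names here are assumptions for illustration, not the verified library API.

import torch.nn.functional as F

ACCUM_STEPS = 4  # micro-batches accumulated per optimizer step

def train(model, optimizer, dataloader):
    for i, (inputs, labels) in enumerate(dataloader):
        # Scale the loss so the accumulated gradient averages over micro-batches.
        loss = F.cross_entropy(model(inputs), labels) / ACCUM_STEPS

        if (i + 1) % ACCUM_STEPS != 0:
            # Intermediate micro-batch: skip gradient synchronization
            # (reduce-scatter) and only accumulate gradients locally.
            # `no_sync()` is an assumed context manager for illustration.
            with optimizer.no_sync():
                optimizer.backward(loss)
        else:
            # Last micro-batch of the window: synchronize gradients,
            # then update parameters and clear the gradient store.
            optimizer.backward(loss)
            optimizer.step()
            optimizer.zero_grad()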