[zero] remove registered gradient hooks (#5687)

* remove registered hooks

fix

fix

fix zero

fix

fix

fix

fix

fix zero

fix zero

fix

fix

fix

* fix

fix

fix
This commit is contained in:
flybird11111
2024-05-07 12:01:38 +08:00
committed by GitHub
parent c25f83c85f
commit 77ec773388
7 changed files with 256 additions and 167 deletions

View File

@@ -6,7 +6,7 @@ from .base_store import BaseStore
class GradientStore(BaseStore):
def __init__(self, *args, partition_grad: bool = False):
def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
super().__init__(*args)
"""
self._grads_of_params mapping the parameter and its gradient slices
@@ -18,9 +18,12 @@ class GradientStore(BaseStore):
}
"""
self._grads_of_params = dict()
# for zero2, it's `param_id: [grad_local_rank]`
# stage 2
self._partition_grads = partition_grad
# grad accumulation
self.require_grad_sync = require_grad_sync
self._working_index = 0 if partition_grad else self._local_rank
# for zero2, it's `param_id: [grad_local_rank]`
self.grad_to_param_mapping = dict()
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: