mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 13:59:08 +00:00
[zero]remove registered gradients hooks (#5687)
* remove registered hooks fix fix fix zero fix fix fix fix fix zero fix zero fix fix fix * fix fix fix
This commit is contained in:
@@ -6,7 +6,7 @@ from .base_store import BaseStore
|
||||
|
||||
|
||||
class GradientStore(BaseStore):
|
||||
def __init__(self, *args, partition_grad: bool = False):
|
||||
def __init__(self, *args, partition_grad: bool = False, require_grad_sync: bool = True):
|
||||
super().__init__(*args)
|
||||
"""
|
||||
self._grads_of_params mapping the parameter and its gradient slices
|
||||
@@ -18,9 +18,12 @@ class GradientStore(BaseStore):
|
||||
}
|
||||
"""
|
||||
self._grads_of_params = dict()
|
||||
# for zero2, it's `param_id: [grad_local_rank]`
|
||||
# stage 2
|
||||
self._partition_grads = partition_grad
|
||||
# grad accumulation
|
||||
self.require_grad_sync = require_grad_sync
|
||||
self._working_index = 0 if partition_grad else self._local_rank
|
||||
|
||||
# for zero2, it's `param_id: [grad_local_rank]`
|
||||
self.grad_to_param_mapping = dict()
|
||||
|
||||
def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List:
|
||||
|
Reference in New Issue
Block a user