mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
[ddp] add set_params_to_ignore for ColoDDP (#1122)
* add set_params_to_ignore for ColoDDP
* polish code
* fix zero hook v2
* add unit test
* polish docstr
This commit is contained in:
@@ -22,6 +22,7 @@ class ZeROHookV2(ParamOpHook):
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
|
||||
chunks = self._chunk_manager.get_chunks(params)
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
@@ -33,6 +34,7 @@ class ZeROHookV2(ParamOpHook):
|
||||
self._gemini_manager.sample_model_data()
|
||||
|
||||
def post_op(self, params):
|
||||
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
|
||||
for p in params:
|
||||
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
|
||||
self._chunk_manager.trans_tensor_state(p, tensor_state)
|
||||
|
Reference in New Issue
Block a user