[ddp] add set_params_to_ignore for ColoDDP (#1122)

* add set_params_to_ignore for ColoDDP

* polish code

* fix zero hook v2

* add unit test

* polish docstr
This commit is contained in:
ver217
2022-06-16 12:54:46 +08:00
committed by GitHub
parent 3175bcb4d8
commit f0a954f16d
3 changed files with 117 additions and 1 deletions

View File

@@ -22,6 +22,7 @@ class ZeROHookV2(ParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -33,6 +34,7 @@ class ZeROHookV2(ParamOpHook):
self._gemini_manager.sample_model_data()
def post_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)