[ddp] add is_ddp_ignored (#2434)

[ddp] rename to is_ddp_ignored
Author: HELSON
Date: 2023-01-11 12:22:45 +08:00
Committed by: GitHub
Parent: a3e5496156
Commit: 7829aa094e
7 changed files with 56 additions and 30 deletions


@@ -8,6 +8,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.utils import is_ddp_ignored

 class TrainingPhase(Enum):
@@ -24,7 +25,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD

     def pre_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -37,7 +38,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.record_model_data_volume()

     def post_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
        for p in params:
            tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
            self._chunk_manager.trans_tensor_state(p, tensor_state)
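
The commit centralizes the `_ddp_to_ignore` attribute check that the old inline `getattr` calls performed at each call site. A minimal sketch of what `colossalai.utils.is_ddp_ignored` plausibly wraps, inferred from the inline check it replaces (the exact signature and docstring in the library may differ):

import torch

def is_ddp_ignored(p: torch.Tensor) -> bool:
    # True if the parameter was flagged to be skipped by DDP/Gemini.
    # Sketch based on the inline check this commit replaces; the
    # attribute name `_ddp_to_ignore` comes from the old code.
    return getattr(p, '_ddp_to_ignore', False)

Centralizing the check keeps the attribute name in one place, so callers such as pre_op and post_op stay unchanged if the flagging mechanism evolves.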