mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -8,6 +8,7 @@ import torch
 from colossalai.gemini import TensorState
 from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor.param_op_hook import ColoParamOpHook
+from colossalai.utils import is_ddp_ignored
 
 
 class TrainingPhase(Enum):
@@ -24,7 +25,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -37,7 +38,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
-        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
+        params = [p for p in params if not is_ddp_ignored(p)]
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)
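For context on the refactor above: the change replaces an inline getattr check with the shared is_ddp_ignored helper from colossalai.utils, so pre_op and post_op skip DDP-ignored parameters the same way. Below is a minimal sketch of what the helper presumably wraps, inferred only from the expression it replaces; the real implementation in colossalai.utils may differ.

import torch

def is_ddp_ignored(p: torch.Tensor) -> bool:
    # Assumed behavior: encapsulate the '_ddp_to_ignore' attribute check
    # that the old inline getattr(p, '_ddp_to_ignore', False) performed.
    return getattr(p, '_ddp_to_ignore', False)

# Illustrative usage: mark one parameter as DDP-ignored and filter it out
# the way pre_op/post_op do.
ignored = torch.nn.Parameter(torch.zeros(4))
ignored._ddp_to_ignore = True
tracked = torch.nn.Parameter(torch.ones(2))
live = [p for p in (ignored, tracked) if not is_ddp_ignored(p)]
assert len(live) == 1 and live[0] is tracked

The surrounding hook logic is unchanged by the diff: pre_op moves the remaining parameters' chunks into the COMPUTE state before each operator runs, and post_op returns them to HOLD during the forward pass (or when a parameter needs no gradient) and to HOLD_AFTER_BWD during the backward pass.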