[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)

This commit is contained in:
Jiarui Fang
2022-11-17 14:43:49 +08:00
committed by GitHub
parent f8a7148dec
commit cc0ed7cf33
4 changed files with 12 additions and 10 deletions

View File

@@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer
@@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager)
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = []
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = {}