Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-16 22:52:25 +00:00)
[Gemini] ZeROHookV2 -> GeminiZeROHook (#1972)
@@ -14,7 +14,7 @@ from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.utils import get_current_device
-from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
+from colossalai.zero.utils.gemini_hook import GeminiZeROHook

 from .reducer import Reducer

@@ -210,7 +210,7 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = ZeROHookV2(gemini_manager)
+        self.param_op_hook = GeminiZeROHook(gemini_manager)
         self.fp32_params: List[ColoTensor] = []
         self.overflow_counter = 0
         self.grads_device: Dict[torch.Tensor, torch.device] = {}
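This commit is a pure rename: the hook class moves from colossalai.zero.utils.zero_hook_v2.ZeROHookV2 to colossalai.zero.utils.gemini_hook.GeminiZeROHook, and its constructor still takes the GeminiManager, so downstream callers only need to update the import. Below is a minimal migration sketch; only the import paths and the GeminiZeROHook(gemini_manager) call come from the diff above, while the forward_with_gemini_hook helper and the ParamOpHookManager.use_hooks context manager are assumptions for illustration.

# Migration sketch (hypothetical usage; not part of this diff).

# Before this commit:
#   from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
#   param_op_hook = ZeROHookV2(gemini_manager)

# After this commit:
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.zero.utils.gemini_hook import GeminiZeROHook

def forward_with_gemini_hook(model, data, gemini_manager):
    # Construct the renamed hook exactly as ZeroDDP.__init__ does in the hunk above.
    param_op_hook = GeminiZeROHook(gemini_manager)
    # Assumption: ParamOpHookManager.use_hooks(...) activates the hook so that
    # ColoParameter operations are intercepted during the forward pass.
    with ParamOpHookManager.use_hooks(param_op_hook):
        return model(data)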