Mirror of https://github.com/hpcaitech/ColossalAI.git
[ddp] add set_params_to_ignore for ColoDDP (#1122)
* add set_params_to_ignore for ColoDDP
* polish code
* fix zero hook v2
* add unit test
* polish docstr
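In short, the commit adds a `_ddp_to_ignore` flag that is checked in ColoDDP, ColoDDPV2, and ZeROHookV2, so that selected parameters are excluded from gradient hooks, chunk management, and the ZeRO op hooks. A minimal usage sketch based on the new docstring (the import path and the ignore criterion are illustrative, not part of this commit):

    import torch.nn as nn
    from colossalai.nn.parallel import ColoDDP   # adjust to your ColossalAI version

    module = nn.Linear(8, 8)                     # toy model for illustration
    params_to_ignore = [p for p in module.parameters() if not p.requires_grad]

    ColoDDP.set_params_to_ignore(params_to_ignore)   # must be called before wrapping
    model = ColoDDP(module)                          # requires an initialized distributed context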
@@ -7,7 +7,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.tensor.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict
+from typing import Dict, Iterable
 from colossalai.logging import get_dist_logger
@@ -38,6 +38,8 @@ class ColoDDP(torch.nn.Module):
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
         self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
         for p in module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if p.requires_grad:
                 p.register_hook(partial(self.grad_handle, p))
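The loop above relies on PyTorch's per-tensor gradient hooks: Tensor.register_hook is called once per parameter, with functools.partial binding the parameter so grad_handle can process its gradient; flagged parameters simply never get a hook. A self-contained sketch of that hook pattern (standard PyTorch, not ColossalAI code):

    import torch
    from functools import partial

    def grad_handle(p, grad):
        # stand-in for ColoDDP.grad_handle: rescale instead of all-reducing
        return grad / 2

    w = torch.nn.Parameter(torch.ones(3))
    if not getattr(w, '_ddp_to_ignore', False) and w.requires_grad:
        w.register_hook(partial(grad_handle, w))

    (w * 2).sum().backward()
    print(w.grad)   # tensor([1., 1., 1.]) -- the hook halved the original gradient of 2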
@@ -55,6 +57,8 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if p.grad.device.type != "cpu":
                 p.grad = p._saved_grad
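After backward, the default stream waits on self.comm_stream so any asynchronous gradient communication issued from grad_handle has finished before the saved gradients are moved back onto p.grad; flagged parameters never went through grad_handle, so they are skipped here as well. A minimal sketch of that stream-synchronization pattern, assuming a CUDA device is available (illustrative only, no collective communication):

    import torch

    if torch.cuda.is_available():
        comm_stream = torch.cuda.Stream()
        grad = torch.ones(1024, device='cuda')
        comm_stream.wait_stream(torch.cuda.current_stream())  # side stream sees the produced grad
        with torch.cuda.stream(comm_stream):
            grad.mul_(0.5)                                    # stand-in for an async all-reduce
        torch.cuda.current_stream().wait_stream(comm_stream)  # reader waits for the side stream
        print(grad.sum().item())                              # 512.0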
@@ -99,6 +103,25 @@ class ColoDDP(torch.nn.Module):
                 p._saved_grad.requires_grad_(False)
             p._saved_grad.zero_()
 
+    @staticmethod
+    def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
+        """Sets parameters to be ignored by DDP.
+        This method must be called before initializing ColoDDP.
+
+        Example::
+            >>> params_to_ignore = []
+            >>> for p in module.parameters():
+            >>>     if should_ignore(p):
+            >>>         params_to_ignore.append(p)
+            >>> ColoDDP.set_params_to_ignore(params_to_ignore)
+            >>> module = ColoDDP(module)
+
+        Args:
+            params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
+        """
+        for p in params_to_ignore:
+            p._ddp_to_ignore = True
+
 
 class ColoDDPV2(ColoDDP):
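The commit message also mentions a unit test, which is not part of this excerpt. A hypothetical sketch of the kind of check such a test could perform, using only the flag this method sets (the should_ignore criterion is made up for illustration):

    import torch.nn as nn

    def should_ignore(name: str) -> bool:
        return 'bias' in name                       # illustrative criterion only

    module = nn.Linear(4, 4)
    params_to_ignore = [p for n, p in module.named_parameters() if should_ignore(n)]

    # equivalent to ColoDDP.set_params_to_ignore(params_to_ignore)
    for p in params_to_ignore:
        p._ddp_to_ignore = True

    for n, p in module.named_parameters():
        assert getattr(p, '_ddp_to_ignore', False) == should_ignore(n)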
@@ -114,6 +137,8 @@ class ColoDDPV2(ColoDDP):
         self.chunk_manager.create_group('fp32_param')
         # TODO: get param order and filter unused params
         for p in module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             assert p.dtype == torch.half
             fp32_p = p.float().detach()
             self.chunk_manager.append_tensor(p, 'fp16_param')
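For context, ColoDDPV2 expects every managed parameter in fp16 and keeps a detached fp32 master copy alongside it; flagged parameters skip both the 'fp16_param' and 'fp32_param' chunk groups and keep whatever dtype they already have. A tiny sketch of the fp16-parameter / fp32-master-copy idea in isolation (no chunk manager involved, update step is hypothetical):

    import torch

    p = torch.nn.Parameter(torch.randn(4).half())
    assert p.dtype == torch.half
    fp32_p = p.float().detach()                  # master copy, as in the hunk above

    # hypothetical optimizer step: update the fp32 copy, then refresh the fp16 param
    fp32_p -= 0.1 * torch.ones_like(fp32_p)
    with torch.no_grad():
        p.copy_(fp32_p.half())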
@@ -133,6 +158,8 @@ class ColoDDPV2(ColoDDP):
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
+            if getattr(p, '_ddp_to_ignore', False):
+                continue
             if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
                 p.grad = None
             else:
@@ -22,6 +22,7 @@ class ZeROHookV2(ParamOpHook):
         self._training_phase = TrainingPhase.FORWARD
 
     def pre_op(self, params):
+        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
         chunks = self._chunk_manager.get_chunks(params)
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -33,6 +34,7 @@ class ZeROHookV2(ParamOpHook):
         self._gemini_manager.sample_model_data()
 
     def post_op(self, params):
+        params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
         for p in params:
             tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
             self._chunk_manager.trans_tensor_state(p, tensor_state)
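With this change both pre_op and post_op drop flagged parameters before touching the chunk manager, so tensors marked with `_ddp_to_ignore` never enter the COMPUTE/HOLD state transitions. The filter is just a list comprehension over the flag; a standalone illustration of the predicate (not the ColossalAI hook itself):

    import torch

    def filter_ignored(params):
        # the same predicate ZeROHookV2 now applies in pre_op/post_op
        return [p for p in params if not getattr(p, '_ddp_to_ignore', False)]

    a = torch.nn.Parameter(torch.zeros(2))
    b = torch.nn.Parameter(torch.zeros(2))
    b._ddp_to_ignore = True
    print(len(filter_ignored([a, b])))   # 1 -- only `a` is still managed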