[ddp] add set_params_to_ignore for ColoDDP (#1122)

* add set_params_to_ignore for ColoDDP

* polish code

* fix zero hook v2

* add unit test

* polish docstr
This commit is contained in:
ver217
2022-06-16 12:54:46 +08:00
committed by GitHub
parent 3175bcb4d8
commit f0a954f16d
3 changed files with 117 additions and 1 deletions

View File

@@ -7,7 +7,7 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.tensor.chunk import TensorState, Chunk
from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Dict
from typing import Dict, Iterable
from colossalai.logging import get_dist_logger
@@ -38,6 +38,8 @@ class ColoDDP(torch.nn.Module):
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
self.dp_world_size = gpc.get_world_size(ParallelMode.DATA)
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
@@ -55,6 +57,8 @@ class ColoDDP(torch.nn.Module):
loss.backward()
torch.cuda.current_stream().wait_stream(self.comm_stream)
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad
@@ -99,6 +103,25 @@ class ColoDDP(torch.nn.Module):
p._saved_grad.requires_grad_(False)
p._saved_grad.zero_()
@staticmethod
def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
"""Sets parameters to be ignored by DDP.
This method must be called before initializing ColoDDP.
Example::
>>> params_to_ignore = []
>>> for p in module.parameters():
>>> if should_ignore(p):
>>> params_to_ignore.append(p)
>>> ColoDDP.set_params_to_ignore(params_to_ignore)
>>> module = ColoDDP(module)
Args:
params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
"""
for p in params_to_ignore:
p._ddp_to_ignore = True
class ColoDDPV2(ColoDDP):
@@ -114,6 +137,8 @@ class ColoDDPV2(ColoDDP):
self.chunk_manager.create_group('fp32_param')
# TODO: get param order and filter unused params
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
assert p.dtype == torch.half
fp32_p = p.float().detach()
self.chunk_manager.append_tensor(p, 'fp16_param')
@@ -133,6 +158,8 @@ class ColoDDPV2(ColoDDP):
def _setup_grads_ptr(self):
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
continue
if self.chunk_manager.get_chunk(p).is_empty or not p.requires_grad:
p.grad = None
else:

View File

@@ -22,6 +22,7 @@ class ZeROHookV2(ParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
@@ -33,6 +34,7 @@ class ZeROHookV2(ParamOpHook):
self._gemini_manager.sample_model_data()
def post_op(self, params):
params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
for p in params:
tensor_state = TensorState.HOLD if self._training_phase == TrainingPhase.FORWARD or not p.requires_grad else TensorState.HOLD_AFTER_BWD
self._chunk_manager.trans_tensor_state(p, tensor_state)