[ddp] add is_ddp_ignored (#2434)

[ddp] rename to is_ddp_ignored
This commit is contained in:
HELSON
2023-01-11 12:22:45 +08:00
committed by GitHub
parent a3e5496156
commit 7829aa094e
7 changed files with 56 additions and 30 deletions

View File

@@ -14,7 +14,7 @@ from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
from colossalai.utils import get_current_device, is_ddp_ignored
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
from .reducer import Reducer
@@ -81,7 +81,7 @@ class ColoDDP(torch.nn.Module):
self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
@@ -116,7 +116,7 @@ class ColoDDP(torch.nn.Module):
if self.rebuild_bucket:
self.reducer.free()
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
if p.grad.device.type != "cpu":
p.grad = p._saved_grad
@@ -232,7 +232,7 @@ class ZeroDDP(ColoDDP):
for p in param_order.generate():
assert isinstance(p, ColoParameter)
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
p.data = p.data.half()
continue
@@ -256,7 +256,7 @@ class ZeroDDP(ColoDDP):
self.chunk_manager.close_all_groups()
self._cast_buffers()
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
for p, fp32_p in zip(params_list, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
@@ -303,7 +303,7 @@ class ZeroDDP(ColoDDP):
def _setup_grads_ptr(self):
for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False):
if is_ddp_ignored(p):
continue
p.grad = None