mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-24 19:17:30 +00:00
@@ -14,7 +14,7 @@ from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
|
||||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import get_current_device, is_ddp_ignored
|
||||
from colossalai.zero.utils.gemini_hook import GeminiZeROHook
|
||||
|
||||
from .reducer import Reducer
|
||||
@@ -81,7 +81,7 @@ class ColoDDP(torch.nn.Module):
|
||||
self.reducer = Reducer(bucket_cap_mb)
|
||||
self.rebuild_bucket = rebuild_bucket
|
||||
for p in module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.requires_grad:
|
||||
p.register_hook(partial(self.grad_handle, p))
|
||||
@@ -116,7 +116,7 @@ class ColoDDP(torch.nn.Module):
|
||||
if self.rebuild_bucket:
|
||||
self.reducer.free()
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
if p.grad.device.type != "cpu":
|
||||
p.grad = p._saved_grad
|
||||
@@ -232,7 +232,7 @@ class ZeroDDP(ColoDDP):
|
||||
for p in param_order.generate():
|
||||
assert isinstance(p, ColoParameter)
|
||||
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.half()
|
||||
continue
|
||||
|
||||
@@ -256,7 +256,7 @@ class ZeroDDP(ColoDDP):
|
||||
self.chunk_manager.close_all_groups()
|
||||
self._cast_buffers()
|
||||
|
||||
params_list = [p for p in param_order.generate() if not getattr(p, '_ddp_to_ignore', False)]
|
||||
params_list = [p for p in param_order.generate() if not is_ddp_ignored(p)]
|
||||
for p, fp32_p in zip(params_list, self.fp32_params):
|
||||
chunk_16 = self.chunk_manager.get_chunk(p)
|
||||
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
|
||||
@@ -303,7 +303,7 @@ class ZeroDDP(ColoDDP):
|
||||
|
||||
def _setup_grads_ptr(self):
|
||||
for p in self.module.parameters():
|
||||
if getattr(p, '_ddp_to_ignore', False):
|
||||
if is_ddp_ignored(p):
|
||||
continue
|
||||
p.grad = None
|
||||
|
||||
|
Reference in New Issue
Block a user