Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-24 22:42:15 +00:00)

[zero] support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin
* polish
* polish

This commit is contained in:
parent c6ab96983a
commit 79cf1b5f33
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

 from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper

 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ class Booster:
         # return loss or outputs if needed
         pass

-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
+        Support torch DDP and Low Level ZeRO-1 for now.

         Args:
-            model (nn.Module): The model to be disabled gradient synchronization.
+            model (nn.Module): The model to be disabled gradient synchronization, for DDP
+            optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1

         Returns:
             contextmanager: Context to disable gradient synchronization.
         """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)

     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
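For context, the new signature lets the caller pass either a DDP-wrapped model or a ZeRO-1 optimizer. Below is a minimal gradient-accumulation sketch built on this API; the names `model`, `optimizer`, `criterion`, `train_dataloader`, `booster`, and `accum_steps` are illustrative assumptions, not taken from this commit.

```python
# Illustrative only: gradient accumulation driven through Booster.no_sync.
for step, (inputs, labels) in enumerate(train_dataloader):
    loss = criterion(model(inputs), labels)
    if (step + 1) % accum_steps != 0:
        # accumulation step: skip gradient all-reduce / reduce-scatter
        with booster.no_sync(model=model, optimizer=optimizer):
            booster.backward(loss, optimizer)
    else:
        # boundary step: synchronize gradients, then update parameters
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```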
@@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
                                                  norm_type=norm_type)
         self.verbose = verbose

+        # set class name with stage, for better error message
+        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
     def support_no_sync(self) -> bool:
-        return False
+        return self.stage == 1

     def control_precision(self) -> bool:
         return True
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert isinstance(optimizer, LowLevelZeroOptimizer)
+        return optimizer.optim.no_sync()
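With `support_no_sync` now a method that returns `self.stage == 1`, the Booster's assertion gates `no_sync` to ZeRO-1 only. The sketch below illustrates that gating; it assumes a launched distributed context, and any constructor arguments beyond `stage` are omitted and assumed to take their defaults.

```python
# Illustrative only: how the new support_no_sync() gate behaves.
zero1_plugin = LowLevelZeroPlugin(stage=1)
zero2_plugin = LowLevelZeroPlugin(stage=2)

assert zero1_plugin.support_no_sync() is True     # ZeRO-1: no_sync allowed
assert zero2_plugin.support_no_sync() is False    # ZeRO-2: Booster.no_sync will fail its assertion

# For stage 1, plugin.no_sync(model, optimizer) ignores the model argument and
# returns optimizer.optim.no_sync(), i.e. the context manager of the wrapped
# low-level ZeRO optimizer.
```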
@@ -61,7 +61,7 @@ class Plugin(ABC):
         pass

     @abstractmethod
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         """
         Context manager to disable gradient synchronization.
         """
@@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
         return model.module.no_sync()
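For the DDP plugin the new `optimizer` argument is accepted but unused; the call still forwards to PyTorch's native `DistributedDataParallel.no_sync`. For reference, this is the standard torch behaviour being delegated to (the names `ddp_model` and `data` are assumed):

```python
# Standard torch.nn.parallel.DistributedDataParallel semantics.
with ddp_model.no_sync():
    ddp_model(data).sum().backward()    # gradients accumulate locally, no all-reduce
ddp_model(data).sum().backward()        # first backward outside no_sync triggers the all-reduce
```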
@@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         False

-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

     def control_precision(self) -> bool:
@@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoParameter, ProcessGroup
-from colossalai.utils import conditional_context
 from colossalai.utils.cuda import get_current_device

 from ._utils import (
@@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         return False


-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """

@@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                 overlap_communication: bool = False,
                 partition_grad: bool = False,    # stage 2 flag
                 cpu_offload: bool = False,    # cpu offload
-                grad_accumulate_interval: int = 1,
                 forced_dtype: Optional[torch.dtype] = None):

-        assert not (partition_grad and grad_accumulate_interval > 1), \
-            "gradient accumulation is not compatible with ZeRO-2"
+        # TODO:
+        # 1. process group api
+        # 2. checkpoint IO
+
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

         # grad accumulation
         self.require_grad_sync = True
-        self._accumulate_intervel = grad_accumulate_interval
-        self._accumulate_step = 0

         colo_pg = self._search_colo_process_group()
         if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         ################################

     def backward(self, loss, retain_graph=False):
+        assert not(self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)

-        self._accumulate_step += 1
-        no_sync = self._accumulate_step < self._accumulate_intervel
-        with conditional_context(self.no_sync(), enable=no_sync):
-            loss.backward(retain_graph=retain_graph)
+        loss.backward(retain_graph=retain_graph)

-        if no_sync:
+        if not self.require_grad_sync:
             return

         self._reduce_grad(self._partition_grads)
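The diff consumes `self.require_grad_sync` and `optimizer.optim.no_sync()` but does not show the `no_sync` context manager itself; the new assertion exists presumably because ZeRO-2 reduce-scatters and partitions gradients during backward, so skipping synchronization would leave each rank with inconsistent partial gradients. Below is a minimal sketch of what that context manager plausibly looks like, assuming it simply toggles `require_grad_sync`; the real method lives on `LowLevelZeroOptimizer` and may differ in detail.

```python
from contextlib import contextmanager

# Assumed sketch: toggle require_grad_sync so that backward() skips
# _reduce_grad() and step() returns early while the context is active.
@contextmanager
def no_sync(self):
    old_require_grad_sync = self.require_grad_sync
    self.require_grad_sync = False
    try:
        yield
    finally:
        self.require_grad_sync = old_require_grad_sync
```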
@@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'

-        if not self._accumulate_step == self._accumulate_intervel:
+        if not self.require_grad_sync:
             return

         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
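A practical consequence of this guard, assuming `no_sync()` toggles `require_grad_sync` as sketched above: calling `step()` inside the accumulation window becomes a no-op rather than an update on unreduced gradients. The names `zero_optimizer` and `micro_loss` below are assumed.

```python
# Illustrative only; relies on the assumption that no_sync() sets
# require_grad_sync to False for the duration of the context.
with zero_optimizer.no_sync():
    zero_optimizer.backward(micro_loss)   # gradients accumulate locally
    zero_optimizer.step()                 # returns early: require_grad_sync is False
```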
@@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             if self._verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
-            self._accumulate_step -= 1
             return

         # record all grads for unscale and clip
@@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):

             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]

-        # reset accumulate step
-        self._accumulate_step = 0
-
         #############################
         # Mixed Precision Utilities #
         #############################
|
@ -9,6 +9,7 @@ from torch.testing import assert_close
|
|||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.testing import spawn
|
from colossalai.testing import spawn
|
||||||
from colossalai.testing.random import seed_all
|
from colossalai.testing.random import seed_all
|
||||||
|
from colossalai.utils import conditional_context
|
||||||
from colossalai.zero import LowLevelZeroOptimizer
|
from colossalai.zero import LowLevelZeroOptimizer
|
||||||
|
|
||||||
|
|
||||||
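The updated tests drive gradient accumulation through `conditional_context`, which enters the given context manager only when the flag is set. A sketch of its semantics is shown below; the real helper lives in `colossalai.utils` and may be implemented differently.

```python
from contextlib import contextmanager

# Assumed semantics of colossalai.utils.conditional_context as used in these
# tests: enter `context` only when `enable` is truthy, otherwise do nothing.
@contextmanager
def conditional_context(context, enable: bool = True):
    if enable:
        with context:
            yield
    else:
        yield
```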
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                            overlap_communication=True,
                                            initial_scale=32,
                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2,
                                            verbose=True)
    zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                            overlap_communication=True,
                                            partition_grad=True,
                                            initial_scale=32,
-                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2)
+                                            clip_grad_norm=1.0)
    # create data
    seed_all(2021 + local_rank)
    input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
        assert torch.equal(zero1_output, zero2_output)

        # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())

        if check_flag:
            for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=False,
                                           reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
    input_data2 = torch.randn(32, 128).cuda()

    def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)

-        # torch-ddp forward
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())

-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
            torch_output = torch_model(cur_data)
            assert torch.equal(zero_output, torch_output)
            torch_output.sum().backward()
@@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
        if check_flag:
            # check grad
            for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                assert torch.equal(p.grad, z1p.grad)

    fwd_bwd_func(0, input_data1, True)