Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero]support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin
* polish
* polish
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ class Booster:
             # return loss or outputs if needed
             pass
 
-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
+           Support torch DDP and Low Level ZeRO-1 for now.
 
         Args:
-            model (nn.Module): The model whose gradient synchronization is to be disabled.
+            model (nn.Module): The model whose gradient synchronization is to be disabled, for DDP.
+            optimizer (OptimizerWrapper): The optimizer whose gradient synchronization is to be disabled, for ZeRO-1.
 
         Returns:
             contextmanager: Context to disable gradient synchronization.
         """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)
 
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
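For illustration, here is a minimal sketch of how the extended no_sync API might be used for gradient accumulation with the ZeRO-1 plugin. Only the booster.no_sync(model, optimizer) call reflects the change in this commit; the surrounding setup (colossalai.launch_from_torch, LowLevelZeroPlugin(stage=1), the toy model, synthetic batches, precision choice) is assumed boilerplate for a typical Booster script launched with torchrun/colossalai run and may differ from your configuration.

    # Sketch only: gradient accumulation with Booster.no_sync (setup is assumed, not from this commit).
    import torch
    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch(config={})
    device = torch.cuda.current_device()

    model = nn.Linear(32, 4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    # ZeRO-1: optimizer states and gradient reduction are owned by the wrapped optimizer.
    # precision='fp32' is chosen only to keep this toy example simple (assumed to be supported).
    plugin = LowLevelZeroPlugin(stage=1, precision='fp32')
    booster = Booster(plugin=plugin)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    accum_steps = 4
    batches = [(torch.randn(8, 32), torch.randn(8, 4)) for _ in range(16)]  # synthetic data

    for step, (inputs, targets) in enumerate(batches):
        inputs, targets = inputs.to(device), targets.to(device)
        loss = criterion(model(inputs), targets) / accum_steps
        if (step + 1) % accum_steps != 0:
            # Accumulation micro-step: skip gradient synchronization across the DP group.
            # For DDP the model wrapper owns the process group; for ZeRO-1 the optimizer
            # wrapper does, which is why no_sync now also accepts the optimizer.
            with booster.no_sync(model, optimizer):
                booster.backward(loss, optimizer)
        else:
            booster.backward(loss, optimizer)  # gradients are synchronized on this step
            optimizer.step()
            optimizer.zero_grad()

The design choice visible in the diff: torch DDP exposes no_sync on the model, while ZeRO-1 shards and reduces gradients through the optimizer wrapper, so Booster.no_sync now takes either (or both) and simply delegates to plugin.no_sync(model, optimizer).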