Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-25 03:31:56 +00:00)
[zero]support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin
* polish
* polish
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ class Booster:
         # return loss or outputs if needed
         pass
 
-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
+        Support torch DDP and Low Level ZeRO-1 for now.
 
         Args:
-            model (nn.Module): The model to be disabled gradient synchronization.
+            model (nn.Module): The model to be disabled gradient synchronization, for DDP
+            optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO-1
 
         Returns:
             contextmanager: Context to disable gradient synchronization.
         """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)
 
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
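With the new signature, one gradient-accumulation loop can serve both backends: TorchDDPPlugin reads the `model` argument, while LowLevelZeroPlugin (stage 1) reads the `optimizer` argument. A minimal sketch, assuming `booster`, `model`, `optimizer`, `criterion`, and `micro_batches` already exist from an earlier `booster.boost(...)` call (none of these names come from the commit itself):

    for i, (inputs, labels) in enumerate(micro_batches):
        if i < len(micro_batches) - 1:
            # intermediate micro-batches: skip gradient all-reduce / reduce-scatter
            with booster.no_sync(model=model, optimizer=optimizer):
                loss = criterion(model(inputs), labels)
                booster.backward(loss, optimizer)
        else:
            # last micro-batch: gradients are synchronized as usual
            loss = criterion(model(inputs), labels)
            booster.backward(loss, optimizer)
            optimizer.step()
            optimizer.zero_grad()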
@@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
                                  norm_type=norm_type)
         self.verbose = verbose
 
+        # set class name with stage, for better error message
+        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
     def support_no_sync(self) -> bool:
-        return False
+        return self.stage == 1
 
     def control_precision(self) -> bool:
         return True
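The `setattr` line bakes the ZeRO stage into the plugin's class name, so the assertion in `Booster.no_sync` shown above reports which stage rejected the call. A standalone toy illustration of the same pattern, using a hypothetical `ToyZeroPlugin` rather than the real plugin class:

    class ToyZeroPlugin:
        def __init__(self, stage: int):
            self.stage = stage
            # mirror the diff: bake the stage into the class name used in error messages
            setattr(self.__class__, "__name__", f"ToyZeroPlugin_ZeRO-{stage}")

        def support_no_sync(self) -> bool:
            # only ZeRO stage 1 supports no_sync after this commit
            return self.stage == 1

    plugin = ToyZeroPlugin(stage=2)
    if not plugin.support_no_sync():
        print(f"The plugin {plugin.__class__.__name__} does not support no_sync.")
    # -> The plugin ToyZeroPlugin_ZeRO-2 does not support no_sync.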
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert isinstance(optimizer, LowLevelZeroOptimizer)
+        return optimizer.optim.no_sync()
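Here the plugin only forwards the call: the boosted optimizer is an `OptimizerWrapper`, and its `.optim` attribute holds the inner ZeRO-1 optimizer whose `no_sync()` supplies the actual context. A toy sketch of that delegation chain (hypothetical `ToyZeroOptimizer` and `ToyOptimizerWrapper`, not the ColossalAI classes):

    from contextlib import contextmanager

    class ToyZeroOptimizer:
        def __init__(self):
            self.require_grad_sync = True

        @contextmanager
        def no_sync(self):
            # temporarily switch off gradient reduction, restore it on exit
            old = self.require_grad_sync
            self.require_grad_sync = False
            try:
                yield
            finally:
                self.require_grad_sync = old

    class ToyOptimizerWrapper:
        def __init__(self, optim):
            self.optim = optim    # same attribute the plugin reaches through

    wrapper = ToyOptimizerWrapper(ToyZeroOptimizer())
    with wrapper.optim.no_sync():              # same call chain as the plugin
        assert wrapper.optim.require_grad_sync is False
    assert wrapper.optim.require_grad_sync is True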
@@ -61,7 +61,7 @@ class Plugin(ABC):
         pass
 
     @abstractmethod
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         """
         Context manager to disable gradient synchronization.
         """
@@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
         return model.module.no_sync()
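For torch DDP the new `optimizer` argument is simply ignored: `model.module` is the underlying `torch.nn.parallel.DistributedDataParallel` instance, and its built-in `no_sync()` is returned unchanged. A self-contained single-process sketch of the PyTorch call the plugin delegates to (gloo backend, world size 1, shown only to illustrate the behavior):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    ddp_model = DDP(torch.nn.Linear(4, 4))          # what TorchDDPModel keeps as .module
    with ddp_model.no_sync():                       # gradient all-reduce is skipped here
        ddp_model(torch.randn(2, 4)).sum().backward()
    ddp_model(torch.randn(2, 4)).sum().backward()   # sync fires on this backward

    dist.destroy_process_group()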
@@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         False
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
 
     def control_precision(self) -> bool: