mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[zero]support no_sync method for zero1 plugin (#4138)
* support no_sync for zero1 plugin
* polish
* polish
This commit is contained in:
@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
norm_type=norm_type)
|
||||
self.verbose = verbose
|
||||
|
||||
# set class name with stage, for better error message
|
||||
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
|
||||
|
||||
def support_no_sync(self) -> bool:
    """Report whether this plugin supports the ``no_sync`` context.

    Only ZeRO stage 1 keeps full (unsharded) gradients on each rank, so
    gradient synchronization can safely be deferred there; stage 2/3
    shard gradients and must synchronize every step.

    Returns:
        bool: True when ``self.stage == 1``, False otherwise.
    """
    # Post-fix behavior: the earlier hard-coded `return False` disabled
    # no_sync even for stage 1, where it is valid.
    return self.stage == 1
||||
def control_precision(self) -> bool:
    """Whether this plugin takes charge of mixed-precision handling.

    Returns:
        bool: always True for this plugin.
    """
    return True
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
    """Create the checkpoint I/O handler paired with this plugin.

    Returns:
        CheckpointIO: a fresh ``LowLevelZeroCheckpointIO`` instance.
    """
    return LowLevelZeroCheckpointIO()
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
    """Return a context manager that skips gradient synchronization.

    Delegates to the wrapped low-level ZeRO optimizer's ``no_sync``;
    meaningful only when :meth:`support_no_sync` is True (stage 1).

    Args:
        model: the wrapped model (unused here; kept for interface parity
            with other plugins' ``no_sync`` signatures).
        optimizer: must wrap a ``LowLevelZeroOptimizer``.

    Returns:
        Iterator[None]: context manager deferring gradient all-reduce.
    """
    # Post-fix behavior: the earlier stub raised NotImplementedError.
    # NOTE(review): validation via assert is stripped under `python -O`;
    # kept for backward compatibility (callers may catch AssertionError).
    assert isinstance(optimizer, LowLevelZeroOptimizer)
    return optimizer.optim.no_sync()
|
Reference in New Issue
Block a user