[zero] support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin

* polish

* polish
This commit is contained in:
LuGY
2023-07-04 12:00:33 +08:00
committed by Hongxin Liu
parent c6ab96983a
commit 79cf1b5f33
8 changed files with 45 additions and 49 deletions

View File

@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
norm_type=norm_type)
self.verbose = verbose
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
def support_no_sync(self) -> bool:
    """Report whether this plugin supports ``no_sync``.

    Returns:
        ``True`` only when ``self.stage == 1``; ``False`` for any other stage.
    """
    # NOTE(review): relies on ``self.stage`` having been assigned earlier
    # (presumably in ``__init__``, which is not visible here) -- confirm.
    return self.stage == 1
def control_precision(self) -> bool:
    """Answer the plugin's precision-control query.

    Returns:
        Always ``True``.
    """
    controls_precision = True
    return controls_precision
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
    """Build the checkpoint I/O object for this plugin.

    Returns:
        A newly constructed ``LowLevelZeroCheckpointIO``.
    """
    checkpoint_io = LowLevelZeroCheckpointIO()
    return checkpoint_io
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
    """Return the ``no_sync`` context provided by the wrapped optimizer.

    Args:
        model: Not used by this implementation.
        optimizer: Must be a ``LowLevelZeroOptimizer``.

    Returns:
        Whatever ``optimizer.optim.no_sync()`` returns.

    Raises:
        AssertionError: If ``optimizer`` is not a ``LowLevelZeroOptimizer``.
    """
    assert isinstance(optimizer, LowLevelZeroOptimizer)
    # Delegate to the inner optimizer's own ``no_sync()``.
    return optimizer.optim.no_sync()