[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)

* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
Hongxin Liu authored on 2023-05-19 19:42:31 +08:00, committed by GitHub
parent 60e6a154bc
commit 3c07a2846e
6 changed files with 128 additions and 82 deletions
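
For orientation (not part of this commit), below is a minimal usage sketch of how this code path is reached through the Booster API. The toy model, the learning rate, and the 'optim.pt' path are illustrative assumptions, and exact GeminiPlugin arguments may differ between ColossalAI versions; a distributed launch (e.g. torchrun --nproc_per_node=2 train.py) is assumed.

    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch(config={})

    plugin = GeminiPlugin()
    booster = Booster(plugin=plugin)

    model = nn.Linear(16, 16)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # With this workaround, every rank writes its own file:
    # 'optim.pt.rank0', 'optim.pt.rank1', ...
    booster.save_optimizer(optimizer, 'optim.pt')

    # Loading expects the same per-rank files, i.e. a checkpoint saved by
    # GeminiPlugin itself with the same number of processes.
    booster.load_optimizer(optimizer, 'optim.pt')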


@@ -52,8 +52,16 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
        Save optimizer to checkpoint but only on master process.
        """
        # TODO(ver217): optimizer state dict is sharded
        warnings.warn(
            'GeminiPlugin does not support saving a full optimizer checkpoint yet. It will be saved on every process.')
        # Workaround: each rank writes its own file, so no state gathering is needed.
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        warnings.warn(
            'GeminiPlugin can only load an optimizer checkpoint saved by itself with the same number of processes.')
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        super().load_optimizer(optimizer, checkpoint)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
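
The essence of the change is the rank suffix appended to the checkpoint path before delegating to GeneralCheckpointIO. A standalone sketch of the same naming scheme, with plain torch.save / torch.load standing in for the parent class and torch.distributed.get_rank() standing in for self.coordinator.rank:

    import torch
    import torch.distributed as dist

    def save_optimizer_per_rank(optimizer: torch.optim.Optimizer, checkpoint: str) -> None:
        # Each process writes its own shard, e.g. 'optim.pt.rank0', 'optim.pt.rank1', ...
        checkpoint = f'{checkpoint}.rank{dist.get_rank()}'
        torch.save(optimizer.state_dict(), checkpoint)

    def load_optimizer_per_rank(optimizer: torch.optim.Optimizer, checkpoint: str) -> None:
        # Each process reads back exactly its own shard, which is why the world size
        # must match the one used at save time.
        checkpoint = f'{checkpoint}.rank{dist.get_rank()}'
        optimizer.load_state_dict(torch.load(checkpoint))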