[plugin] a workaround for ZeRO plugins' optimizer checkpoint (#3780)

* [test] refactor torch ddp checkpoint test

* [plugin] update low level zero optim checkpoint

* [plugin] update gemini optim checkpoint
Author: Hongxin Liu
Date: 2023-05-19 19:42:31 +08:00
Committed by: GitHub
Parent: 60e6a154bc
Commit: 3c07a2846e
6 changed files with 128 additions and 82 deletions


@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
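
The GeneralCheckpointIO import added above is used in the next hunk: save_unsharded_optimizer calls the ancestor's method explicitly rather than through super(), so TorchDDPCheckpointIO's master-only override is skipped and every rank writes its shard. A minimal sketch of that Python pattern (the class names Base, MasterOnly, and PerRank are hypothetical, chosen only to mirror the hierarchy in the diff):

class Base:
    def save(self) -> str:
        return 'save on every process'

class MasterOnly(Base):
    def save(self) -> str:
        return 'save on master process only'

class PerRank(MasterOnly):
    def save(self) -> str:
        # super().save() would dispatch to MasterOnly.save via the MRO;
        # naming the ancestor directly, Base.save(self), bypasses it, just as
        # the diff calls GeneralCheckpointIO.save_unsharded_optimizer(self, ...).
        return Base.save(self)

assert PerRank().save() == 'save on every process'

super() walks the method resolution order and would land on the intermediate override; naming the ancestor class directly sidesteps it.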
@@ -32,8 +32,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
+        warnings.warn(
+            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        warnings.warn(
+            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        super().load_optimizer(optimizer, checkpoint)
 
 
 class LowLevelZeroModel(ModelWrapper):
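
Taken together, the two methods above implement the workaround from the commit title: each rank appends .rank<i> to the checkpoint path, saves its own shard of the optimizer state, and can only reload it when relaunched with the same world size. A minimal sketch of the scheme in plain PyTorch (the helper names save_optimizer_per_rank and load_optimizer_per_rank are hypothetical, not part of this commit; it assumes torch.distributed has already been initialized):

import torch
import torch.distributed as dist

def save_optimizer_per_rank(optimizer: torch.optim.Optimizer, checkpoint: str) -> None:
    # Every process writes its local (sharded) optimizer state to its own file.
    rank = dist.get_rank()
    torch.save(optimizer.state_dict(), f'{checkpoint}.rank{rank}')

def load_optimizer_per_rank(optimizer: torch.optim.Optimizer, checkpoint: str) -> None:
    # Each file holds exactly one rank's shard, so this only works when the
    # job is restarted with the same number of processes that saved it.
    rank = dist.get_rank()
    optimizer.load_state_dict(torch.load(f'{checkpoint}.rank{rank}'))

Under ZeRO no single rank holds the full optimizer state, so producing one unsharded file would require gathering state across ranks; writing one file per rank is the cheap interim workaround, as the updated TODO in the diff notes.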