Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test
* [plugin] update low level zero optim checkpoint
* [plugin] update gemini optim checkpoint
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
 from torch.utils.data import DataLoader
 
-from colossalai.checkpoint_io import CheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
 from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
@@ -32,8 +32,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        # TODO(ver217): optimizer state dict is sharded
-        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+        # TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
+        warnings.warn(
+            'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
+
+    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
+        warnings.warn(
+            'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
+        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
+        super().load_optimizer(optimizer, checkpoint)
 
 
 class LowLevelZeroModel(ModelWrapper):
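For context, a minimal usage sketch of this workaround follows. It assumes the Booster API of the same release (colossalai.launch_from_torch, Booster, LowLevelZeroPlugin, booster.save_optimizer / load_optimizer); the toy model, optimizer, and checkpoint path are hypothetical and only illustrate the per-rank file naming this commit introduces. Because the ZeRO optimizer state is sharded, each rank writes its own piece to '<checkpoint>.rank<rank>', and loading only works with the same number of processes that saved it.

# Minimal sketch of the per-rank optimizer checkpoint workaround.
# Assumption: Booster API of the same release; the toy model, optimizer and
# checkpoint path are hypothetical and only illustrate the file naming.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=LowLevelZeroPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Saving warns that no full checkpoint is possible and writes one file per
# process, e.g. 'optimizer.pt.rank0', 'optimizer.pt.rank1', ...
booster.save_optimizer(optimizer, 'optimizer.pt')

# Loading appends the same '.rank<rank>' suffix, so the world size must match
# the one used when the checkpoint was saved.
booster.load_optimizer(optimizer, 'optimizer.pt')

Run the sketch under torchrun with a fixed process count (e.g. torchrun --nproc_per_node=2 script.py) and reuse the same count when loading, since each rank reads back only its own shard.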