Mirror of https://github.com/hpcaitech/ColossalAI.git
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test
* [plugin] update low level zero optim checkpoint
* [plugin] update gemini optim checkpoint
@@ -52,8 +52,16 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
        Save optimizer to checkpoint but only on master process.
        """
        # TODO(ver217): optimizer state dict is sharded
        warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
        warnings.warn(
            'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
        checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
        super().load_optimizer(optimizer, checkpoint)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
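
The hunk above implements the workaround by giving every rank its own optimizer checkpoint file (suffixed with `.rank{rank}`) and requiring the same process count at load time. Below is a minimal sketch of that per-rank save/load pattern in plain PyTorch, outside the ColossalAI plugin API; the `per_rank_path` helper and the function names are illustrative assumptions, not part of this commit.

import torch
import torch.distributed as dist
from torch.optim import Optimizer


def per_rank_path(checkpoint: str) -> str:
    # Hypothetical helper mirroring the commit's f'{checkpoint}.rank{rank}' naming.
    rank = dist.get_rank() if dist.is_initialized() else 0
    return f'{checkpoint}.rank{rank}'


def save_optimizer_per_rank(optimizer: Optimizer, checkpoint: str) -> None:
    # Every process writes its own shard of the optimizer state, as the
    # GeminiPlugin workaround does, instead of gathering a full state dict.
    torch.save(optimizer.state_dict(), per_rank_path(checkpoint))


def load_optimizer_per_rank(optimizer: Optimizer, checkpoint: str) -> None:
    # Must run with the same number of processes that saved the checkpoint,
    # since each rank only reads back its own shard.
    state = torch.load(per_rank_path(checkpoint), map_location='cpu')
    optimizer.load_state_dict(state)

Restoring with a different world size would leave some rank files unread and others missing, which is why load_optimizer warns that the checkpoint can only be loaded by the same number of processes that saved it.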