Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 02:26:51 +00:00)
[checkpointio] General Checkpointing of Sharded Optimizers (#3984)
@@ -32,7 +32,6 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
             super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
 
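The pattern in this hunk, rank-0-gated writes, is the core of TorchDDPCheckpointIO: under DDP every rank holds a full replica of the model, so only the master process touches disk. A minimal standalone sketch of that gate, assuming a torch.distributed process group is already initialized; save_on_master is a hypothetical helper, and plain torch.save stands in for GeneralCheckpointIO's serialization:

import torch
import torch.nn as nn

from colossalai.cluster import DistCoordinator


def save_on_master(model: nn.Module, checkpoint: str) -> None:
    # Hypothetical helper mirroring the gate above: every DDP rank holds an
    # identical replica, so rank 0 alone writes the checkpoint file.
    coordinator = DistCoordinator()
    if coordinator.is_master():
        torch.save(model.state_dict(), checkpoint)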
@@ -54,11 +53,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
                            model: nn.Module,
                            checkpoint_path: str,
                            gather_dtensor: bool = False,
-                           variant: Optional[str] = None,
+                           prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
         """
         Save model to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, variant, max_shard_size, use_safetensors)
+            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
+                               size_per_shard: int):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+
+
 class TorchDDPModel(ModelWrapper):
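End to end, these methods are reached through the Booster checkpoint API rather than called directly. A hedged usage sketch, assuming a launched distributed run and the Booster/TorchDDPPlugin interface of this era; argument names such as shard and size_per_shard (in MB) mirror the signatures above but may differ across versions:

import torch
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})

model = torch.nn.Linear(16, 16).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# Dispatches to TorchDDPCheckpointIO.save_sharded_optimizer: the optimizer
# state dict is split into shards of at most size_per_shard MB, written only
# by the master rank.
booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=True, size_per_shard=1024)

# The model path goes through save_sharded_model with the renamed `prefix` arg.
booster.save_model(model, 'ckpt/model', shard=True)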