Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[checkpointio] sharded optimizer checkpoint for DDP plugin (#4002)
@@ -52,7 +52,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(self,
                            model: nn.Module,
                            checkpoint_path: str,
-                           gather_dtensor: bool = False,
+                           gather_dtensor: bool = True,
                            prefix: Optional[str] = None,
                            max_shard_size: int = 1024,
                            use_safetensors: bool = False):
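The first hunk only flips the default of `gather_dtensor` from `False` to `True`, so distributed tensors are presumably gathered before a sharded model save unless the caller opts out. For context, a minimal usage sketch of how these save paths are typically exercised through the plugin, assuming the `Booster`/`TorchDDPPlugin` API of this release and a `torchrun`-style launch; the model, optimizer, and checkpoint paths are placeholders, not from the commit:

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Assumes torchrun-style environment variables (RANK, WORLD_SIZE, ...) are set.
colossalai.launch_from_torch(config={})

model = nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

booster = Booster(plugin=TorchDDPPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True routes to TorchDDPCheckpointIO.save_sharded_model /
# save_sharded_optimizer; after this commit gather_dtensor defaults to True.
booster.save_model(model, "ckpt/model", shard=True)
booster.save_optimizer(optimizer, "ckpt/optim", shard=True)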
@@ -62,8 +62,12 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
 
-    def save_sharded_optimier(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str,
-                              size_per_shard: int):
+    def save_sharded_optimizer(self,
+                               optimizer: Optimizer,
+                               checkpoint: str,
+                               gather_dtensor: bool = True,
+                               prefix: Optional[str] = None,
+                               size_per_shard: int = 1024):
         """
         Save optimizer to checkpoint but only on master process.
         """
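The second hunk fixes the typo in the old method name (`save_sharded_optimier`), reflows the signature, and adds defaults, but the diff is truncated before the method body. A plausible completion, assuming it mirrors the master-only guard used by `save_sharded_model` just above; the `super().save_sharded_optimizer(...)` delegation is an assumption, not shown in this diff:

    def save_sharded_optimizer(self,
                               optimizer: Optimizer,
                               checkpoint: str,
                               gather_dtensor: bool = True,
                               prefix: Optional[str] = None,
                               size_per_shard: int = 1024):
        """
        Save optimizer to checkpoint but only on master process.
        """
        # Assumed body: delegate to the GeneralCheckpointIO implementation
        # on the master rank only, mirroring save_sharded_model above.
        if self.coordinator.is_master():
            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)

Guarding on `self.coordinator.is_master()` keeps non-master ranks from writing duplicate shard files, matching the model-saving behavior shown earlier in the diff.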