[booster] fixed the torch ddp plugin with the new checkpoint api (#3442)

Author: Frank Lee
Date: 2023-04-06 09:43:51 +08:00
Committed by: GitHub
Parent: 8f740deb53
Commit: 7d8d825681

3 changed files with 11 additions and 10 deletions

@@ -33,20 +33,20 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         return super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
         # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint)
+            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)

     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
         if self.coordinator.is_master():
-            super().save_unsharded_optimizer(optimizer, checkpoint)
+            super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """