[checkpoint] refactored the API and added safetensors support (#3427)

* [checkpoint] refactored the API and added safetensors support

* polish code
This commit is contained in:
Frank Lee
2023-04-04 15:23:01 +08:00
committed by GitHub
parent 26b7aac0be
commit 1beb85cc25
9 changed files with 579 additions and 280 deletions

View File

@@ -33,7 +33,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
"""
Save model to checkpoint but only on master process.
"""
@@ -41,7 +41,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""