mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[checkpoint] refactored the API and added safetensors support (#3427)
* [checkpoint] refactored the API and added safetensors support * polish code
This commit is contained in:
@@ -33,7 +33,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
|
||||
return super().load_unsharded_model(model, checkpoint, strict=strict)
|
||||
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str):
|
||||
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
"""
|
||||
@@ -41,7 +41,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
||||
if self.coordinator.is_master():
|
||||
super().save_unsharded_model(model, checkpoint)
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
|
Reference in New Issue
Block a user