Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-06 11:32:10 +00:00
[checkpointio] support async model save (#6131)
* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -33,13 +33,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
         super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         """
         Save model to checkpoint but only on master process.
         """
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(
+                model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
+            )
 
     def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
         """
@@ -71,6 +75,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
@@ -78,7 +83,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
             super().save_sharded_model(
-                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+                model.unwrap(),
+                checkpoint_path,
+                gather_dtensor,
+                prefix,
+                max_shard_size,
+                use_safetensors,
+                use_async=use_async,
             )
 
     def load_sharded_model(
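The diff above threads a new use_async flag through both TorchDDPCheckpointIO.save_unsharded_model and the sharded save path, forwarding it to the parent GeneralCheckpointIO. Below is a minimal usage sketch, not taken from the commit: it assumes a torchrun-style launch, that Booster.boost wraps the model in the ModelWrapper the assertions require, and that plugin.get_checkpoint_io() returns the patched checkpoint IO.

    # Illustrative sketch only, not from the commit: exercising the new
    # use_async flag on TorchDDPCheckpointIO. Run under torchrun with
    # colossalai installed; names outside the diff are assumptions.
    import torch
    import torch.nn as nn

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchDDPPlugin

    colossalai.launch_from_torch()

    model = nn.Linear(8, 8)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    # boost() wraps the model in a ModelWrapper, satisfying the
    # "Please boost the model before saving!" assertion in the diff.
    model, optimizer, *_ = booster.boost(model, optimizer)

    checkpoint_io = plugin.get_checkpoint_io()
    # use_async=True is forwarded to GeneralCheckpointIO, which offloads
    # the tensor write instead of blocking the master rank until it ends.
    checkpoint_io.save_unsharded_model(
        model,
        checkpoint="ckpt/model.safetensors",
        gather_dtensor=True,
        use_safetensors=True,
        use_async=True,
    )

With use_async=True the master rank may return from the save call before the bytes are fully on disk, so a real script should keep the process alive until pending writes complete; that machinery lives in GeneralCheckpointIO and is not shown in this diff.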