[checkpointio] support async model save (#6131)

* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Hongxin Liu
2024-11-14 11:38:10 +08:00
parent 5a03d2696d
commit d4a436051d
7 changed files with 209 additions and 28 deletions

View File

@@ -33,13 +33,17 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(
self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(
model.unwrap(), checkpoint, gather_dtensor, use_safetensors, use_async=use_async
)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
@@ -71,6 +75,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
use_async: bool = False,
):
"""
Save model to checkpoint but only on master process.
@@ -78,7 +83,13 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
model.unwrap(),
checkpoint_path,
gather_dtensor,
prefix,
max_shard_size,
use_safetensors,
use_async=use_async,
)
def load_sharded_model(