[checkpointio] support async model save (#6131)

* [checkpointio] support async model save

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Author: Hongxin Liu
Date:   2024-11-14 11:38:10 +08:00
Commit: d4a436051d
Parent: 5a03d2696d

7 changed files with 209 additions and 28 deletions

@@ -259,10 +259,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
         model.update_master_params()
 
-    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(
+        self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
+    ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
-        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+        return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
 
     def save_sharded_model(
         self,
@@ -272,11 +274,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         prefix: Optional[str] = None,
         max_shard_size: int = 1024,
         use_safetensors: bool = False,
+        use_async: bool = False,
     ):
         assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
         model._force_wait_all_gather()
         return super().save_sharded_model(
-            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async=use_async
         )
 
     def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
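
The change threads a new use_async flag through LowLevelZeroCheckpointIO's save paths into the shared TorchDDPCheckpointIO implementation, defaulting to False so existing callers keep the blocking behavior. Below is a minimal sketch of how a caller might opt in after this commit; the plugin and booster setup is the standard ColossalAI boost flow, while the final synchronization hook is an assumption and is only shown as a commented placeholder.

    # Minimal sketch of opting into the async save path (run under torchrun).
    # The boost flow below uses real ColossalAI APIs; the final
    # synchronization call is an assumption, not a confirmed method name.
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    colossalai.launch_from_torch()

    plugin = LowLevelZeroPlugin(stage=1)
    booster = Booster(plugin=plugin)

    model = torch.nn.Linear(16, 16).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    ckpt_io = plugin.get_checkpoint_io()

    # With use_async=True the call returns while the gathered state dict
    # is written in the background, so training is not blocked on disk I/O.
    ckpt_io.save_unsharded_model(
        model,
        "model.safetensors",
        gather_dtensor=True,
        use_safetensors=True,
        use_async=True,
    )

    # Hypothetical: wait for pending background writes before exiting or
    # reloading the file; the exact hook may differ in the real API.
    # ckpt_io.synchronize()

Keeping use_async=False as the default means none of the existing call sites in the seven touched files change behavior; only callers that explicitly pass use_async=True take the new code path.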