[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
Author: flybird11111
Date: 2024-11-18 17:52:24 +08:00
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
15 changed files with 374 additions and 46 deletions

@@ -67,7 +67,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -157,7 +159,13 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         model.unwrap().load_state_dict(fsdp_state_dict, strict=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.
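
The low-level writer the new flag dispatches to is not shown in these hunks. As a rough, implementation-agnostic illustration of the pattern the "support pin mem" and async-save bullets point at (snapshot tensors into pinned host buffers, then write on a background thread so training can resume immediately), here is a standalone sketch; it assumes nothing about ColossalAI's actual internals.

    # Illustrative pattern only -- not the code added by this commit.
    import threading
    import torch

    def async_save_state_dict(state_dict: dict, path: str) -> threading.Thread:
        snapshot = {}
        for name, value in state_dict.items():
            if isinstance(value, torch.Tensor):
                # Pinned (page-locked) host memory allows a fast, non-blocking
                # device-to-host copy before the file write starts.
                buf = torch.empty(value.shape, dtype=value.dtype, device="cpu", pin_memory=True)
                buf.copy_(value, non_blocking=True)
                snapshot[name] = buf
            else:
                snapshot[name] = value
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure all D2H copies have completed
        writer = threading.Thread(target=torch.save, args=(snapshot, path))
        writer.start()
        return writer  # caller should join() before relying on the file

A real implementation must additionally handle nested optimizer state dicts and make sure a pending write finishes before the same path is overwritten, which is why the API change threads `use_async` through both the sharded and unsharded entry points.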