[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
Author:    flybird11111
Date:      2024-11-18 17:52:24 +08:00
Committer: Hongxin Liu
Parent:    b90835bd32
Commit:    eb69e640e5
15 changed files with 374 additions and 46 deletions
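
For orientation, here is a minimal sketch of how the new flag might be driven from the Booster API. That Booster.save_optimizer forwards use_async down to GeminiCheckpointIO is an assumption on my part, not something shown in this diff:

# Hypothetical usage sketch; the `use_async` forwarding through Booster
# is assumed, not confirmed by the hunks below.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()                      # set up the distributed env
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
booster = Booster(plugin=GeminiPlugin())
model, optimizer, *_ = booster.boost(model, optimizer)
# ... training steps ...
# Kick off the optimizer checkpoint without stalling the training loop.
booster.save_optimizer(optimizer, "ckpt/optimizer.pt", use_async=True)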


@@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
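
The "support pin mem" commit in the message above suggests the async path stages tensors in pinned host memory before writing. A rough sketch of that idea follows; all helper names here are illustrative, not taken from the commit:

# Illustrative only: `_to_pinned_cpu` and `async_save` are names I made up
# to show the pinned-memory + background-writer pattern.
import threading
import torch

def _to_pinned_cpu(t: torch.Tensor) -> torch.Tensor:
    # Stage the tensor in pinned host memory so the device-to-host copy can
    # overlap with compute and the writer thread can stream it to disk.
    buf = torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
    buf.copy_(t, non_blocking=True)
    return buf

def async_save(state_dict: dict, path: str) -> threading.Thread:
    # Flat mapping for brevity; a real optimizer state dict is nested
    # ("state" / "param_groups") and would need a recursive walk.
    pinned = {
        k: _to_pinned_cpu(v) if torch.is_tensor(v) and v.is_cuda else v
        for k, v in state_dict.items()
    }
    torch.cuda.synchronize()  # ensure the non-blocking copies have landed
    writer = threading.Thread(target=torch.save, args=(pinned, path), daemon=True)
    writer.start()
    return writer  # caller should join() before relying on the file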
@@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
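
Against the widened signature, an opted-in sharded save could look like the sketch below. The checkpoint_io handle and the megabyte unit for size_per_shard follow the usual ColossalAI conventions and are assumptions here:

# Sketch of a call against the new signature; `checkpoint_io` stands in for
# the GeminiCheckpointIO instance the plugin constructs.
from pathlib import Path

checkpoint_io.save_sharded_optimizer(
    optimizer,                 # a boosted GeminiOptimizer
    Path("ckpt/optim"),        # folder that will hold the shards plus an index file
    gather_dtensor=True,
    prefix=None,               # no filename prefix
    size_per_shard=1024,       # shard size cap, in MB by convention (assumed)
    use_async=True,            # new in this commit: write shards off the critical path
)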