[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
Author: flybird11111
Date: 2024-11-18 17:52:24 +08:00
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
15 changed files with 374 additions and 46 deletions


@@ -52,7 +52,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -113,13 +115,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
 
     def load_sharded_optimizer(
         self,
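
Likewise, a hedged sketch of the sharded save path with the new keyword: only the parameter list shown in the hunk above comes from this commit, and checkpoint_io/optimizer are assumed to have been prepared as in the previous sketch (boosted optimizer, distributed launch); the directory name is illustrative.

```python
# Sharded optimizer save with the new use_async flag. size_per_shard caps the size of each
# shard file (ColossalAI documents this in MB, default 1024); the checkpoint argument names
# a directory that will hold the shard files and their index.
checkpoint_io.save_sharded_optimizer(
    optimizer,           # must be an OptimizerWrapper (i.e. boosted), per the assert above
    "optim_ckpt",        # illustrative checkpoint directory
    gather_dtensor=True,
    prefix=None,
    size_per_shard=1024,
    use_async=True,      # opt into the non-blocking write path; False keeps the old behavior
)
```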