Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
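The commit message above mentions asynchronous optimizer save/load and pinned-memory support. As background only, here is a minimal sketch of the general technique, not ColossalAI's actual implementation: copy device tensors into pinned (page-locked) CPU buffers, then serialize them to disk on a background thread so training can continue. The helper name and the flat-dict assumption are mine.

import threading
import torch


def async_save_state_dict(state_dict: dict, path: str) -> threading.Thread:
    """Copy tensors into pinned CPU buffers, then write them to disk on a background thread."""
    cpu_state = {}
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            # Pinned memory makes the device-to-host copy fast and allows it to be non-blocking.
            buf = torch.empty(value.size(), dtype=value.dtype,
                              pin_memory=torch.cuda.is_available())
            buf.copy_(value, non_blocking=True)
            cpu_state[name] = buf
        else:
            cpu_state[name] = value
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure all async copies have landed before the writer starts

    writer = threading.Thread(target=torch.save, args=(cpu_state, path))
    writer.start()
    return writer  # caller should join() before reading or overwriting the checkpoint file

A real implementation additionally has to handle nested optimizer state dicts, sharding, and buffer reuse; this sketch only illustrates why pinned memory shows up alongside async saving in the commit message.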
@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """

         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):

     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ class CheckpointIO(ABC):
         """

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
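Based only on the signature shown in this diff, a caller-side usage sketch might look like the following. GeneralCheckpointIO is one concrete CheckpointIO; whether a given implementation actually performs the save asynchronously when use_async=True is not shown in this hunk, and the model, optimizer, and path below are placeholders.

import torch
from colossalai.checkpoint_io import GeneralCheckpointIO

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters())

ckpt_io = GeneralCheckpointIO()
# shard=True routes to save_sharded_optimizer(..., use_async=use_async);
# shard=False routes to save_unsharded_optimizer(..., use_async=use_async).
ckpt_io.save_optimizer(optimizer, "optimizer_ckpt", shard=False, use_async=True)

With a truly asynchronous backend the caller would typically also need to wait for any pending background writes before exiting or re-saving to the same path; that part of the workflow is outside the lines changed here.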