[async io] support async io (#6137)

* support async optimizer save/load

* fix

* fix

* support pin mem

* Update low_level_zero_plugin.py

* fix

* fix

* fix

* fix

* fix
commit eb69e640e5 (parent b90835bd32)
Author:    flybird11111
Committer: Hongxin Liu
Date:      2024-11-18 17:52:24 +08:00

15 changed files with 374 additions and 46 deletions
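
The change threads a new `use_async` flag through the optimizer side of the checkpoint IO stack, so callers can opt into asynchronous saving (with pinned-memory staging, per the commit message) instead of blocking training on disk writes. A minimal usage sketch follows; it assumes the public entry point is `save_optimizer` (only its docstring and dispatch body are visible in the first hunk below) and that `GeneralCheckpointIO` is a suitable concrete implementation:

    # Hedged sketch, not taken from the repository's tests.
    import torch
    from torch.optim import Adam

    from colossalai.checkpoint_io import GeneralCheckpointIO

    model = torch.nn.Linear(8, 8)
    optimizer = Adam(model.parameters(), lr=1e-3)

    ckpt_io = GeneralCheckpointIO()
    ckpt_io.save_optimizer(
        optimizer,
        "optimizer.ckpt",   # checkpoint path
        shard=False,        # dispatches to save_unsharded_optimizer(...)
        use_async=True,     # new in this commit: hand the write off instead of blocking
    )
    # Pending asynchronous writes still have to be flushed before the process
    # exits; the exact synchronization hook is not shown in this excerpt.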

View File

@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):
     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@
         """

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
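
For orientation, here is a stand-alone sketch of the contract these abstract signatures imply: `use_async=False` keeps the old blocking save, while `use_async=True` hands the write to a background worker and returns immediately. The class below is hypothetical and deliberately does not subclass colossalai's CheckpointIO, so it stays self-contained; the repository's real implementation (pinned-memory staging, safetensors writer) is not reproduced here.

    import threading
    from pathlib import Path

    import torch
    from torch.optim import Optimizer


    class ThreadedOptimizerSaver:  # hypothetical, for illustration only
        """Mimics the use_async contract: block by default, write in background on request."""

        def __init__(self):
            self._pending: list[threading.Thread] = []

        def save_unsharded_optimizer(
            self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
        ):
            state_dict = optimizer.state_dict()
            if not use_async:
                torch.save(state_dict, checkpoint)  # old blocking path
                return
            writer = threading.Thread(target=torch.save, args=(state_dict, checkpoint), daemon=True)
            writer.start()
            self._pending.append(writer)

        def synchronize(self):
            """Block until all in-flight writes are finished (call before exit)."""
            for writer in self._pending:
                writer.join()
            self._pending.clear()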

View File

@@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
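
Only the signature change is visible in this hunk. The commit message also mentions pinned-memory support, and the utils file later in this diff imports `move_and_save` from `colossalai.utils.safetensors`; since that writer's API is not shown here, the snippet below is only a generic sketch of the staging pattern such an async save relies on: copy device tensors into pinned host buffers, synchronize the copies, then let a background thread persist the CPU-side state dict.

    import threading

    import torch


    def _to_pinned_cpu(obj):
        """Recursively copy tensors into pinned host buffers via non-blocking D2H copies."""
        if torch.is_tensor(obj):
            buf = torch.empty(obj.shape, dtype=obj.dtype, device="cpu", pin_memory=True)
            buf.copy_(obj, non_blocking=True)
            return buf
        if isinstance(obj, dict):
            return {k: _to_pinned_cpu(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(_to_pinned_cpu(v) for v in obj)
        return obj


    def async_save_state_dict(state_dict: dict, path: str) -> threading.Thread:
        cpu_state = _to_pinned_cpu(state_dict)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # make sure the async copies have landed in the pinned buffers
        writer = threading.Thread(target=torch.save, args=(cpu_state, path), daemon=True)
        writer.start()
        return writer  # caller must join() before relying on the file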

View File

@@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.

View File

@@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.

View File

@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"

 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file


-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)
-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)
     param_group_file = GROUP_FILE_NAME
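
The filename dispatch added here is small enough to restate as a self-contained sketch. The constants are copied from the hunk above; the dash-joining `add_prefix` stand-in is an assumption (the real helper lives in this utils module and may join differently), and the real function also builds the param-group filename, which is omitted here.

    STATES_NAME = "pytorch_optim.bin"
    SAFE_STATE_NAME = "optimizer.safetensors"
    STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
    SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"


    def pick_optimizer_filenames(prefix: str = None, use_safetensors: bool = False):
        def add_prefix(name: str, prefix: str) -> str:  # simplified stand-in
            return f"{prefix}-{name}" if prefix else name

        states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
        save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
        return add_prefix(states_name, prefix), add_prefix(save_index_file, prefix)


    assert pick_optimizer_filenames(use_safetensors=True) == (
        "optimizer.safetensors",
        "optimizer.safetensors.index.json",
    )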