Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 04:24:47 +00:00)
[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Committed by: Hongxin Liu
Parent: b90835bd32
Commit: eb69e640e5
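
For orientation before the diff, here is a minimal usage sketch of the new flag. It is an assumption based on the signatures below (plus the fact that GeneralCheckpointIO is exported from colossalai.checkpoint_io), not code taken from this commit:

import torch
from colossalai.checkpoint_io import GeneralCheckpointIO

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

ckpt_io = GeneralCheckpointIO()
# shard=False routes through save_unsharded_optimizer(..., use_async=use_async);
# shard=True would route through save_sharded_optimizer instead.
ckpt_io.save_optimizer(optimizer, "optimizer.ckpt", shard=False, use_async=True)

With use_async=False (the default) the behaviour is unchanged, so existing callers keep working.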
@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
 
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)
 
     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):
 
     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ class CheckpointIO(ABC):
         """
 
     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
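
Since the abstract signatures above now carry use_async, custom CheckpointIO subclasses must accept the extra argument as well. A hedged sketch of a minimal override (the class name and body are hypothetical, and the other abstract methods are omitted for brevity):

from pathlib import Path

import torch
from torch.optim import Optimizer

from colossalai.checkpoint_io import CheckpointIO


class MyCheckpointIO(CheckpointIO):
    # A real subclass must also implement the model and sharded-optimizer hooks.
    def save_unsharded_optimizer(
        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
    ):
        # This sketch accepts use_async but always writes synchronously.
        torch.save(optimizer.state_dict(), checkpoint)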
@@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
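
The GeneralCheckpointIO body shown here still writes synchronously through save_state_dict; the commit message mentions pinned-memory support, so an async path would presumably stage the states in pinned CPU buffers and hand the write to a background worker. A rough, hypothetical sketch of that idea (none of the helper names below come from this commit):

from concurrent.futures import ThreadPoolExecutor

import torch

_writer = ThreadPoolExecutor(max_workers=1)


def async_save_state_dict(state_dict: dict, checkpoint: str):
    """Hypothetical helper: stage tensors in pinned CPU memory, write in the background."""

    def to_pinned_cpu(obj):
        if isinstance(obj, torch.Tensor) and obj.is_cuda:
            buf = torch.empty(obj.shape, dtype=obj.dtype, pin_memory=True)
            buf.copy_(obj, non_blocking=True)  # device-to-host copy can overlap with compute
            return buf
        if isinstance(obj, dict):
            return {k: to_pinned_cpu(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(to_pinned_cpu(v) for v in obj)
        return obj

    staged = to_pinned_cpu(state_dict)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # ensure the copies finished before the file write starts
    return _writer.submit(torch.save, staged, checkpoint)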
@@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             # Update master params if mixed-precision training is enabled.
             model_before_wrapping.update_master_params()
 
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()
 
     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"
 
 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file
 
 
-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)
 
-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)
 
     param_group_file = GROUP_FILE_NAME
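
A small illustration of the updated helper, assuming it lives in colossalai.checkpoint_io.utils and returns the states name, index name and param-group name, as the surrounding lines suggest:

from colossalai.checkpoint_io.utils import get_optimizer_base_filenames

# With use_safetensors=True the new SAFE_* constants are used.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(use_safetensors=True)
# states_name      -> "optimizer.safetensors"
# save_index_file  -> "optimizer.safetensors.index.json"
# param_group_file -> "pytorch_optim_group.bin"

# Without the flag, the legacy .bin names are unchanged.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames()
# states_name -> "pytorch_optim.bin", save_index_file -> "pytorch_optim.bin.index.json"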