[async io] support async io (#6137)

* support async optimizer save/load
* support pinned-memory state dicts
* Update low_level_zero_plugin.py
* assorted follow-up fixes

Commit eb69e640e5 (parent b90835bd32) in hpcaitech/ColossalAI (https://github.com/hpcaitech/ColossalAI.git), committed by Hongxin Liu.
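For orientation, here is a hypothetical usage sketch of the flag this commit introduces. The plugin/booster setup, the tiny model, and the file names are illustrative assumptions, not part of the commit; only the `use_async` keyword on `save_unsharded_optimizer` and the extension-based load come from the diff below.

# Hypothetical usage sketch (assumes colossalai.launch_from_torch() was already
# called and a CUDA device is available). Only the use_async keyword and the
# extension-based load are taken from this commit; everything else is setup.
import torch
from torch import nn

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

plugin = LowLevelZeroPlugin(stage=2)
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)

ckpt_io = plugin.get_checkpoint_io()
# Async save: optimizer states are staged into cached pinned buffers and
# written in the background via tensornvme's AsyncFileWriter.
ckpt_io.save_unsharded_optimizer(optimizer, "optimizer.safetensors", use_async=True)
# Synchronous path (previous behaviour) is unchanged.
ckpt_io.save_unsharded_optimizer(optimizer, "optimizer.bin", use_async=False)
# Loading picks the reader by file extension (see load_unsharded_optimizer below).
ckpt_io.load_unsharded_optimizer(optimizer, "optimizer.safetensors")
# Assumption: background writes queued in ckpt_io.async_writers are flushed by
# the surrounding CheckpointIO machinery before the checkpoint is considered done.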
@@ -24,6 +24,7 @@ from colossalai.checkpoint_io.utils import (
     get_shard_filename,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -113,7 +114,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.

         Args:
@@ -125,9 +128,34 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)

+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)
+
     def save_sharded_optimizer(
         self,
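Pulled out of the class for clarity, the new unsharded save/load logic amounts to the sketch below. The pinned-buffer cache, the AsyncFileWriter call, and the save_nested/load_flat usage mirror the diff; the module-level cache, the standalone function wrappers, and the n_entries default are assumptions for illustration only.

# Standalone sketch of the logic added above; not the library implementation.
from typing import Dict

from tensornvme.async_file_io import AsyncFileWriter

from colossalai.checkpoint_io.utils import load_state_dict
from colossalai.utils.safetensors import load_flat, save_nested

_PINNED_CACHE: Dict[int, dict] = {}  # mirrors self.pinned_state_dicts, keyed by id(optimizer)


def async_save_unsharded(optimizer, checkpoint: str, async_writers: list, n_entries: int = 32):
    """Stage optimizer state into cached pinned buffers and write it in the background."""
    pinned = _PINNED_CACHE.setdefault(id(optimizer), {})
    state_dict = optimizer.state_dict(pinned)  # the ZeRO optimizer fills the pinned buffers
    # n_entries stands in for self.N_WRITE_ENTRIES, whose value is not shown in the diff.
    writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=n_entries, backend="pthread")
    save_nested(writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
    async_writers.append(writer)  # kept alive until the caller flushes the pending writes


def load_unsharded(optimizer, checkpoint: str):
    """Pick the reader by extension, exactly as load_unsharded_optimizer now does."""
    if checkpoint.endswith(".safetensors"):
        state = load_flat(checkpoint)
    else:
        state = load_state_dict(checkpoint)
    optimizer.load_state_dict(state)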
@@ -136,6 +164,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
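The sharded path gains the same keyword. A hypothetical call through the checkpoint IO (directory name is a placeholder; `ckpt_io` and `optimizer` as in the earlier sketch):

# Hypothetical call site for the extended signature above.
ckpt_io.save_sharded_optimizer(
    optimizer,
    checkpoint="optim_ckpt",  # directory that will receive the shard files and index
    size_per_shard=1024,      # same default as the signature above
    use_async=True,           # new: each shard is written through AsyncFileWriter
)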
@@ -161,10 +190,16 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)

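Why cache pinned_state_dicts per optimizer id? Page-locked host buffers are costly to allocate but make device-to-host copies non-blocking, so caching them lets repeated checkpoints reuse the same staging memory. A generic illustration of that reuse pattern in plain PyTorch (not ColossalAI's code; the helper name and cache layout are invented for the example):

import torch

_pinned_buffers: dict = {}  # keyed by id(owner), mirroring self.pinned_state_dicts above


def stage_to_pinned(owner, name: str, tensor: torch.Tensor) -> torch.Tensor:
    """Copy a (possibly CUDA) tensor into a reusable page-locked CPU buffer."""
    cache = _pinned_buffers.setdefault(id(owner), {})
    buf = cache.get(name)
    if buf is None or buf.shape != tensor.shape or buf.dtype != tensor.dtype:
        # Pinned allocations are slow, so they are created once and reused.
        buf = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
        cache[name] = buf
    buf.copy_(tensor, non_blocking=True)  # D2H copy can overlap with compute
    return buf


# e.g. staged = stage_to_pinned(optimizer, "state.0.exp_avg", exp_avg_gpu_tensor)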
@@ -184,7 +219,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = AsyncFileWriter(
+                        fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                    )
+                    save_nested(f_writer, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)
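The per-shard branch queues one writer per shard file and keeps it in self.async_writers so the write can complete after the call returns. The same fan-out/join pattern, reduced to standard-library terms (an analogy only, not the tensornvme-backed implementation):

import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List

_pool = ThreadPoolExecutor(max_workers=4)  # stands in for the pthread backend


def write_shard_async(checkpoint_dir: str, shard_file: str, payload: bytes) -> Future:
    """Kick off a background write for one shard and return its handle."""
    path = os.path.join(checkpoint_dir, shard_file)

    def _write() -> None:
        with open(path, "wb") as f:
            f.write(payload)

    return _pool.submit(_write)


def save_all_shards(checkpoint_dir: str, shards: Dict[str, bytes]) -> None:
    pending: List[Future] = []  # analogous to self.async_writers
    for name, payload in shards.items():
        pending.append(write_shard_async(checkpoint_dir, name, payload))
    # The caller (in the plugin, the CheckpointIO machinery) must wait for every
    # pending write before treating the checkpoint as complete.
    for fut in pending:
        fut.result()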
@@ -223,7 +269,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                from colossalai.utils.safetensors import load_flat
+
+                state_dict = load_flat(shard_file)
+            else:
+                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
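Finally, the sharded load inspects each shard's extension and routes flattened-safetensors shards through load_flat, while legacy torch-format shards keep the previous reader. A compact sketch of that dispatch (load_flat and load_shard_state_dict are the names used in the diff; the wrapper function and the dict return annotation are illustrative):

from pathlib import Path

from colossalai.checkpoint_io.utils import load_shard_state_dict
from colossalai.utils.safetensors import load_flat


def read_optimizer_shard(shard_file: str) -> dict:
    """Route each shard to the matching reader, as the loop above now does."""
    if shard_file.endswith(".safetensors"):
        # Flattened safetensors layout produced by the async save path.
        return load_flat(shard_file)
    # Legacy torch-serialized shard.
    return load_shard_state_dict(Path(shard_file), use_safetensors=False)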