Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-07 12:01:39 +00:00)
[checkpointio]support asyncio for 3d (#6152)
* fix
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Update utils.py
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@@ -22,6 +22,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@@ -69,6 +70,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
@@ -76,9 +78,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@@ -88,7 +92,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internel method that breaks state_dict of model into shards within limited size.

@@ -102,6 +110,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param_)
param_ = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
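The hunk above stages each gathered parameter into a page-locked (pinned) CPU buffer before handing it to the sharder, so the asynchronous writer can later flush host memory to disk without stalling the GPU. A minimal sketch of that staging pattern, assuming nothing beyond plain PyTorch (the cache dict and function name are illustrative, not ColossalAI APIs):

from typing import Dict

import torch


def stage_to_pinned(name: str, tensor: torch.Tensor, cache: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Copy `tensor` into a reusable pinned CPU buffer keyed by `name`."""
    if name not in cache:
        # Allocate once; page-locked memory keeps later device->host copies and file writes cheap.
        cache[name] = torch.empty_like(tensor, device="cpu", pin_memory=torch.cuda.is_available())
    cache[name].copy_(tensor)  # device -> pinned host copy
    return cache[name]         # this host view is what the shard writer receives


# The cache outlives a single save call, so pinned memory is allocated only once.
pinned_cache: Dict[str, torch.Tensor] = {}
param = torch.randn(4, 4)  # stand-in for a gathered, unpadded parameter
host_view = stage_to_pinned("layer.weight", param, pinned_cache)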
@@ -111,6 +124,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@@ -122,6 +140,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
@@ -136,6 +159,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
):
# An internel method that breaks state_dict of optimizer into shards within limited size.

@@ -153,6 +177,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = param

param_id = param_info["param2id"][id(working_param)]
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
@@ -162,6 +189,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
)

block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
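For optimizer states the pinned cache is nested, matching the Dict[int, Dict[str, torch.Tensor]] annotation above: the outer key is the param_id, the inner key the optimizer state name. A rough sketch of that layout (the state names and helper are illustrative):

from typing import Dict

import torch

# Outer key: the plugin-assigned param_id; inner key: the optimizer state name.
PinnedOptimCache = Dict[int, Dict[str, torch.Tensor]]


def stage_optim_state(param_id: int, state: Dict[str, torch.Tensor], cache: PinnedOptimCache) -> Dict[str, torch.Tensor]:
    per_param = cache.setdefault(param_id, {})
    staged = {}
    for key, value in state.items():
        if key not in per_param:
            per_param[key] = torch.empty_like(value, device="cpu", pin_memory=torch.cuda.is_available())
        per_param[key].copy_(value)
        staged[key] = per_param[key]
    return staged


cache: PinnedOptimCache = {}
fake_state = {"exp_avg": torch.zeros(8), "exp_avg_sq": torch.zeros(8)}
shard_ready = stage_optim_state(param_id=0, state=fake_state, cache=cache)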
@@ -216,15 +244,31 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0

if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
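In the async branch above, async_save_state_dict_shards returns the accumulated size plus a list of writer handles, and the handles are stashed in self.async_writers so the actual disk writes can be finished later instead of blocking save_sharded_model. A generic collect-now, wait-later sketch of the same idea built on a thread pool (the real ColossalAI writers are safetensors-based objects, so everything here is illustrative):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List

import torch

_pool = ThreadPoolExecutor(max_workers=2)


def async_write_shard(path: str, shard: Dict[str, torch.Tensor]) -> Future:
    # Kick off the file write in the background and hand back a handle immediately.
    return _pool.submit(torch.save, shard, path)


pending: List[Future] = []  # plays the role of self.async_writers
for i, shard in enumerate([{"w": torch.ones(2)}, {"b": torch.zeros(2)}]):
    pending.append(async_write_shard(f"/tmp/shard_{i}.bin", shard))

# Later, a synchronize/wait step blocks until every shard has hit disk.
for fut in pending:
    fut.result()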
@@ -234,16 +278,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)

else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +303,25 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)

if control_saving:
assert (
@@ -448,26 +493,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):

# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0

if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
pinned_state_dicts=pinned_state_dicts,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0

if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)

if control_saving:
# Store param groups.
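As with the model path, the pinned buffers are memoized on the CheckpointIO instance keyed by id(optimizer) (or id(model)), so repeated saves of the same object reuse the pinned host memory rather than re-allocating it. A small sketch of that memoization (class and attribute names are illustrative):

from typing import Dict

import torch


class PinnedCacheOwner:
    def __init__(self) -> None:
        # id(obj) -> {tensor name -> pinned CPU buffer}, mirroring self.pinned_state_dicts above.
        self.pinned_state_dicts: Dict[int, Dict[str, torch.Tensor]] = {}

    def cache_for(self, obj: object) -> Dict[str, torch.Tensor]:
        # One persistent pinned cache per model/optimizer object.
        return self.pinned_state_dicts.setdefault(id(obj), {})


owner = PinnedCacheOwner()
model = torch.nn.Linear(2, 2)
assert owner.cache_for(model) is owner.cache_for(model)  # the same cache is reused across saves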
@@ -498,18 +563,33 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)

# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)

total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)

if control_saving:
assert (
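Because the async path writes optimizer shards as .safetensors instead of .bin, the per-pipeline-stage suffix above has to be spliced into whichever extension is in use. A tiny sketch of that filename rewrite (the base names are made up):

def stage_shard_name(states_name: str, pp_rank: int, use_async: bool) -> str:
    suffix = f"-stage-{pp_rank + 1:05d}-shard"
    if use_async:
        # the async path writes flat safetensors shards
        return states_name.replace(".safetensors", f"{suffix}.safetensors")
    return states_name.replace(".bin", f"{suffix}.bin")


assert stage_shard_name("optimizer-00001.bin", pp_rank=2, use_async=False) == "optimizer-00001-stage-00003-shard.bin"
assert stage_shard_name("optimizer-00001.safetensors", pp_rank=2, use_async=True) == (
    "optimizer-00001-stage-00003-shard.safetensors"
)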
@@ -622,7 +702,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
continue

file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
if file_path.endswith(".safetensors"):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)

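Loading mirrors the save path: shards ending in .safetensors were written flat by the async path and are read back with load_flat, while legacy .bin shards still go through the torch pickle loader. A minimal sketch of that dispatch; the un-flattening helper and its key scheme are assumptions for illustration, not the actual colossalai.utils.safetensors format:

from pathlib import Path
from typing import Any, Dict

import torch
from safetensors.torch import load_file


def unflatten_optim_state(flat: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    # Assumes keys of the form "<param_id>.<state_name>"; the real key scheme is
    # defined by colossalai.utils.safetensors and may well differ.
    nested: Dict[int, Dict[str, torch.Tensor]] = {}
    for key, tensor in flat.items():
        pid, state_name = key.split(".", 1)
        nested.setdefault(int(pid), {})[state_name] = tensor
    return {"state": nested}


def load_optimizer_shard(file_path: str) -> Dict[str, Any]:
    if file_path.endswith(".safetensors"):
        # Flat shard written by the async path; un-flatten it before use
        # (in ColossalAI this is the job of load_flat).
        return unflatten_optim_state(load_file(file_path))
    # Legacy .bin shard written with torch.save.
    return torch.load(Path(file_path), map_location="cpu")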
@@ -672,7 +755,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
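For the unsharded model path, the whole gathered state_dict is mirrored into pinned host buffers (create_pinned_state_dict allocates them on first use), overwritten in place on each save, and then handed to a non-blocking writer whose handle joins self.async_writers. A compact sketch under those assumptions, using a plain thread and safetensors in place of the ColossalAI writer object:

import threading
from typing import Dict

import torch
from safetensors.torch import save_file


def create_pinned_copy(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Illustrative stand-in for create_pinned_state_dict: one pinned CPU buffer per entry.
    pin = torch.cuda.is_available()
    return {k: torch.empty_like(v, device="cpu", pin_memory=pin) for k, v in state_dict.items()}


pinned_caches: Dict[int, Dict[str, torch.Tensor]] = {}  # keyed by id(model), as in the hunk above


def async_unsharded_save(model: torch.nn.Module, path: str) -> threading.Thread:
    state_dict = model.state_dict()
    cache = pinned_caches.setdefault(id(model), create_pinned_copy(state_dict))
    for name, param in state_dict.items():
        cache[name].copy_(param)  # device -> pinned host copy
    writer = threading.Thread(target=save_file, args=(cache, path))
    writer.start()  # the caller keeps the handle and joins it later
    return writer


model = torch.nn.Linear(4, 4)
handle = async_unsharded_save(model, "/tmp/model.safetensors")
handle.join()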
@@ -686,12 +777,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:

from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import save

if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
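On this branch the master merges the partial state_dicts it received from the other ranks with dict.update, exactly as the loop above does; how the list is gathered is outside this hunk. A runnable single-process sketch of that merge using an object all-gather (the gather mechanism is an assumption, not necessarily what ColossalAI uses here):

import os
from typing import Dict

import torch
import torch.distributed as dist

# A single-process gloo group just to make the sketch runnable end to end.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

local_piece: Dict[str, torch.Tensor] = {"stage0.weight": torch.ones(2)}
gathered = [None] * dist.get_world_size()
dist.all_gather_object(gathered, local_piece)  # every rank contributes its partial dict

complete_state_dict: Dict[str, torch.Tensor] = {}
for piece in gathered:
    complete_state_dict.update(piece)  # same merge as the loop in the hunk above

dist.destroy_process_group()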
@@ -757,6 +850,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
original_shape = optimizer.param_info["param2shape"][id(working_param)]

local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
@@ -776,7 +870,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
]
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save

flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
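A safetensors file can only hold a flat mapping from string keys to tensors, so the nested optimizer state_dict ({"param_groups": ..., "state": {param_id: {name: tensor}}}) is flattened first and the non-tensor structure travels in the file metadata; that is what _flatten_optim_state_dict returns above. A rough sketch of such a flattening (the key scheme and metadata encoding are assumptions, not the actual ColossalAI format):

import json
from typing import Any, Dict, Tuple

import torch


def flatten_optim_state(state_dict: Dict[str, Any]) -> Tuple[Dict[str, torch.Tensor], Dict[str, str]]:
    flat: Dict[str, torch.Tensor] = {}
    for param_id, states in state_dict["state"].items():
        for name, value in states.items():
            if isinstance(value, torch.Tensor):
                flat[f"state.{param_id}.{name}"] = value
    # Everything that is not a tensor (param_groups, step counts, ...) rides along as JSON metadata.
    metadata = {"param_groups": json.dumps(state_dict["param_groups"])}
    return flat, metadata


nested = {
    "param_groups": [{"lr": 1e-3, "params": [0]}],
    "state": {0: {"exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}},
}
flat, meta = flatten_optim_state(nested)  # ready for a safetensors-style save(..., metadata=meta)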
@@ -792,7 +898,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
save_state_dict(state_dict, checkpoint, use_safetensors=False)
if use_async:
from colossalai.utils.safetensors import save

flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict)
for k, v in flatten_state_dict.items():
self.pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = self.pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)

def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
@@ -818,7 +936,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"

# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
state_dict = load_state_dict(checkpoint)
if checkpoint.endswith(".safetensors"):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)

# Load param_groups.
updated_groups = []
@@ -872,6 +993,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_zero: bool,
inplace: bool,
device: torch.device = torch.device("cpu"),
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -895,6 +1017,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)

for k, v in state_.items():
if v is None:
continue
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
@@ -915,7 +1039,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)

state_[k] = v.detach().clone().to(device)
if pinned_state_dicts is not None:
if k not in pinned_state_dicts:
pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[k].copy_(v)
state_[k] = pinned_state_dicts[k]
else:
state_[k] = v.detach().clone().to(device)

return state_
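Just before the copy above, states that were padded for even tensor-parallel sharding are trimmed back to their original length along the padded dimension; conceptually this is a narrow on that dim, with ColossalAI's padded-tensor helpers doing the real bookkeeping. A toy illustration:

import torch


def unpad(v: torch.Tensor, original_len: int, padding_dim: int) -> torch.Tensor:
    # Drop the trailing padding added to make the dim divisible by tp_size.
    return v.narrow(padding_dim, 0, original_len).contiguous()


padded = torch.arange(12.0).reshape(6, 2)  # e.g. padded from 5 rows up to 6
restored = unpad(padded, original_len=5, padding_dim=0)
assert restored.shape == (5, 2)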