[checkpointio] support async IO for 3D (#6152)

* fix

* fix

* fix

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update utils.py

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111
2024-12-23 10:24:22 +08:00
committed by GitHub
parent aaafb38851
commit 130229fdcb
17 changed files with 776 additions and 188 deletions
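
For context, this is roughly how the feature added here would be exercised from user code. The sketch below is an assumption based on the diff, not the documented API: it presumes `Booster.save_model` / `save_optimizer` accept a `use_async` flag and that the plugin's checkpoint IO flushes its pending writers before the checkpoint is consumed.

```python
# Hypothetical usage sketch; flag names and launch details are assumptions.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # e.g. under `torchrun --nproc_per_node=4 train.py`

model = nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# tp/pp/sp sizes can be raised for full 3D parallelism; kept small for the sketch.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, zero_stage=1)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

# With use_async=True the calls return after handing shards to background writers:
# tensors are staged into pinned CPU buffers and written out as .safetensors files.
booster.save_model(model, "ckpt/model", shard=True, use_async=True)
booster.save_optimizer(optimizer, "ckpt/optimizer", shard=True, use_async=True)
# The queued writers are awaited by the checkpoint IO before later loads/teardown;
# the exact synchronization entry point is not part of this diff.
```
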


@@ -22,6 +22,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
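
The new `_flatten_optim_state_dict` / `load_flat` imports exist because safetensors files can only hold a flat mapping from string keys to tensors, while an optimizer state dict nests tensors under parameter ids and mixes in non-tensor entries (`param_groups`, integer `step`s). The helpers below are an illustrative re-derivation of that idea under an assumed `state.<param_id>.<state_name>` key scheme, not ColossalAI's actual implementation:

```python
from typing import Any, Dict, Tuple

import torch

def flatten_optim_state(state_dict: Dict[str, Any]) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Turn {'state': {pid: {name: value}}, 'param_groups': [...]} into a flat
    tensor dict plus JSON-serializable metadata for everything that is not a tensor."""
    flat: Dict[str, torch.Tensor] = {}
    metadata: Dict[str, Any] = {"param_groups": state_dict["param_groups"], "non_tensor": {}}
    for pid, states in state_dict["state"].items():
        for name, value in states.items():
            key = f"state.{pid}.{name}"
            if isinstance(value, torch.Tensor):
                flat[key] = value
            else:  # e.g. a plain-int "step"
                metadata["non_tensor"][key] = value
    return flat, metadata

def unflatten_optim_state(flat: Dict[str, torch.Tensor], metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse of flatten_optim_state: rebuild the nested optimizer state dict."""
    state: Dict[int, Dict[str, Any]] = {}
    for key, value in list(flat.items()) + list(metadata["non_tensor"].items()):
        _, pid, name = key.split(".", 2)
        state.setdefault(int(pid), {})[name] = value
    return {"state": state, "param_groups": metadata["param_groups"]}
```
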
@@ -69,6 +70,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
sp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True,
) -> None:
@@ -76,9 +78,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.global_dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.sp_group = sp_group
self.dp_rank = dist.get_rank(self.global_dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.sp_rank = dist.get_rank(self.sp_group)
self.global_dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
@@ -88,7 +92,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
@staticmethod
def _model_sharder(
model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024
model: nn.Module,
prefix: str = "",
keep_vars: bool = False,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks state_dict of model into shards within limited size.
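
The sharding this comment describes is size-bounded accumulation: keep appending tensors to the current shard and emit it once the running byte count crosses `size_per_shard` (in MB). A standalone simplification of that loop follows; ColossalAI's `StateDictSharder.append_param` hands the finished block back to the caller instead of yielding directly.

```python
from collections import OrderedDict
from typing import Dict, Iterator, Tuple

import torch

def shard_state_dict(
    state_dict: Dict[str, torch.Tensor], size_per_shard_mb: int = 1024
) -> Iterator[Tuple[OrderedDict, int]]:
    """Yield (shard, shard_size_in_bytes) pairs bounded by size_per_shard_mb."""
    limit = size_per_shard_mb * 1024 * 1024
    shard, shard_size = OrderedDict(), 0
    for name, tensor in state_dict.items():
        tensor_size = tensor.numel() * tensor.element_size()
        if shard and shard_size + tensor_size > limit:
            yield shard, shard_size  # current shard is full, start a new one
            shard, shard_size = OrderedDict(), 0
        shard[name] = tensor
        shard_size += tensor_size
    if shard:
        yield shard, shard_size  # last, possibly undersized shard
```
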
@@ -102,6 +110,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(param_)
param_ = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
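
The `pinned_state_dicts` branch added above caches one page-locked (pinned) CPU buffer per tensor name, so repeated checkpoints reuse the same staging memory and the device-to-host copy hands a stable CPU tensor to the background writer. The pattern in isolation (a sketch, not the exact helper):

```python
from typing import Dict

import torch

def stage_to_pinned(name: str, tensor: torch.Tensor,
                    pinned_cache: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Copy `tensor` into a cached pinned CPU buffer keyed by `name` and return it."""
    if name not in pinned_cache:
        # Pinned allocation is expensive, so it happens once per tensor name and is
        # reused by every later checkpoint of the same model or optimizer.
        pinned_cache[name] = torch.empty_like(tensor, device="cpu", pin_memory=True)
    # The copy targets pinned memory, so for CUDA tensors it can overlap with other
    # work; the async writer only reads the buffer after the copy has finished.
    pinned_cache[name].copy_(tensor)
    return pinned_cache[name]
```

In the sharder above the cache is keyed by `prefix + name` and stored per model under `self.pinned_state_dicts[id(model)]`, which is why only the rank that actually writes (`control_saving`) allocates it.
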
@@ -111,6 +124,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for name, buf in model.named_buffers():
if buf is not None and name not in non_persist_buffers_set:
buffer = buf if keep_vars else buf.detach()
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
pinned_state_dicts[prefix + name].copy_(buffer)
buffer = pinned_state_dicts[prefix + name]
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
@@ -122,6 +140,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is not torch.nn.Module.get_extra_state
):
extra_state = model.get_extra_state()
if pinned_state_dicts is not None:
if extra_state_key not in pinned_state_dicts:
pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
pinned_state_dicts[extra_state_key].copy_(extra_state)
extra_state = pinned_state_dicts[extra_state_key]
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
@@ -136,6 +159,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dp_group: ProcessGroup,
tp_group: ProcessGroup,
size_per_shard: int = 1024,
pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
):
# An internal method that breaks state_dict of optimizer into shards within limited size.
@@ -153,6 +177,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
working_param = param
param_id = param_info["param2id"][id(working_param)]
if pinned_state_dicts is not None:
if param_id not in pinned_state_dicts:
pinned_state_dicts[param_id] = {}
original_shape = param_info["param2shape"][id(working_param)]
state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
@@ -162,6 +189,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=tp_group,
use_zero=use_zero,
inplace=False,
pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None,
)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
@@ -216,15 +244,31 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts
)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
@@ -234,16 +278,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
is_master=control_saving,
use_safetensors=use_safetensors,
)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose and self.coordinator.is_master():
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
# When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +303,25 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, returned_state_dict, writers = async_save_state_dict_shards(
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_pp_format=True,
n_write_entries=191,
state_preprocess=False,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True,
)
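
`async_save_state_dict_shards` (and the `save` helper used further down) return writer handles that are collected into `self.async_writers` instead of blocking until the bytes reach disk. Conceptually this is a fire-and-collect pattern; a rough thread-based sketch of it, not ColossalAI's implementation:

```python
import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, Iterator, List, Tuple

import torch
from safetensors.torch import save_file

_writer_pool = ThreadPoolExecutor(max_workers=4)

def async_save_shards(
    shards: Iterator[Tuple[Dict[str, torch.Tensor], int]],
    checkpoint_dir: str,
    base_name: str = "model",
) -> Tuple[int, List[Future]]:
    """Submit one background write per shard and return (total_bytes, pending writers)."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    total_size, writers = 0, []
    for i, (shard, shard_size) in enumerate(shards):
        path = os.path.join(checkpoint_dir, f"{base_name}-{i + 1:05d}.safetensors")
        # The shard tensors were already staged into pinned CPU buffers by the
        # sharder, so the worker thread only serializes and writes bytes.
        writers.append(_writer_pool.submit(save_file, shard, path))
        total_size += shard_size
    return total_size, writers

# The caller keeps the handles (self.async_writers in the diff) and waits on them
# before the pinned buffers are reused or the process shuts down:
#   for writer in writers: writer.result()
```
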
if control_saving:
assert (
@@ -448,26 +493,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0
if use_async and control_saving:
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = {}
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.global_dp_group,
tp_group=self.tp_group,
size_per_shard=size_per_shard,
pinned_state_dicts=pinned_state_dicts,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
index_file = CheckpointIndexFile(checkpoint)
control_saving = self.dp_rank == 0 and self.tp_rank == 0
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
)
if control_saving:
# Store param groups.
@@ -498,18 +563,33 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
if not use_async:
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
else:
states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
if use_async:
total_size, writers = async_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
state_preprocess=True,
)
self.async_writers.extend(writers)
else:
total_size = save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True,
)
if control_saving:
assert (
@@ -622,7 +702,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
continue
file_path = os.path.join(ckpt_root_path, filename)
if file_path.endswith(".safetensors"):
state_dict = load_flat(file_path)
else:
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
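
`load_flat` is assumed to undo the flattening when a `.safetensors` optimizer shard is read back. With the public safetensors API, the round trip could look roughly like the following, reusing the illustrative `flatten_optim_state` / `unflatten_optim_state` helpers sketched earlier; the metadata encoding here is an assumption, and ColossalAI's on-disk layout may differ.

```python
import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

# A tiny optimizer just to produce a realistic nested state dict.
param = torch.nn.Parameter(torch.randn(4, 4))
opt = torch.optim.Adam([param])
param.grad = torch.zeros_like(param)
opt.step()

flat, meta = flatten_optim_state(opt.state_dict())
# safetensors metadata values must be strings, so the non-tensor parts are JSON-encoded.
save_file(flat, "optim-shard-00001.safetensors", metadata={"meta": json.dumps(meta)})

# Reading: recover tensors and metadata, then rebuild the nested state dict.
with safe_open("optim-shard-00001.safetensors", framework="pt", device="cpu") as f:
    meta = json.loads(f.metadata()["meta"])
    flat = {key: f.get_tensor(key) for key in f.keys()}
restored = unflatten_optim_state(flat, meta)
```
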
@@ -672,7 +755,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is not used, let master rank directly save the collected state_dict.
if self.tp_rank == 0:
if use_async:
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async)
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors)
else:
@@ -686,12 +777,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for _state_dict in state_dict_list:
complete_state_dict.update(_state_dict)
if use_async:
from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
self.pinned_state_dicts[id(model)][name].copy_(param)
complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
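
`create_pinned_state_dict` is used above but not defined in this diff; from its call sites it presumably allocates one pinned CPU buffer per entry of the given state dict, and the per-`id(model)` / `id(optimizer)` keying means the allocation only happens on the first async save. A minimal sketch under that assumption:

```python
from typing import Dict

import torch

def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Allocate a pinned CPU buffer matching each tensor in state_dict (assumed behavior)."""
    return {
        name: torch.empty_like(tensor, device="cpu", pin_memory=True)
        for name, tensor in state_dict.items()
    }
```
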
@@ -757,6 +850,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# gather complete state from tp shards & dp shards
param_id = optimizer.param_info["param2id"][id(working_param)]
original_shape = optimizer.param_info["param2shape"][id(working_param)]
local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state(
state,
working_param,
@@ -776,7 +870,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
]
state_dict = {"param_groups": param_groups, "state": local_states}
if self.coordinator.is_master():
if use_async:
from colossalai.utils.safetensors import save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
for k, v in flatten_state_dict.items():
pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
else:
# When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict.
states_list = [None for _ in range(self.pp_size)]
@@ -792,7 +898,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_dict = {"param_groups": param_groups, "state": dict()}
for _states in states_list:
state_dict["state"].update(_states)
if use_async:
from colossalai.utils.safetensors import save
flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict)
if id(optimizer) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict)
pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
for k, v in flatten_state_dict.items():
pinned_state_dicts[k].copy_(v)
flatten_state_dict[k] = pinned_state_dicts[k]
writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata)
self.async_writers.append(writer)
else:
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
@@ -818,7 +936,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
# Complete optimizer state_dict loaded from checkpoint, need to be processed later.
if checkpoint.endswith(".safetensors"):
state_dict = load_flat(checkpoint)
else:
state_dict = load_state_dict(checkpoint)
# Load param_groups.
updated_groups = []
@@ -872,6 +993,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_zero: bool,
inplace: bool,
device: torch.device = torch.device("cpu"),
pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
@@ -895,6 +1017,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if v is None:
continue
if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards.
if use_zero:
@@ -915,7 +1039,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
if pinned_state_dicts is not None:
if k not in pinned_state_dicts:
pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu")
pinned_state_dicts[k].copy_(v)
state_[k] = pinned_state_dicts[k]
else:
state_[k] = v.detach().clone().to(device)
return state_
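
For reference, the gather performed by `gather_from_sharded_optimizer_state` proceeds in two steps that the sketch below reproduces in isolation: first all-gather the flat ZeRO slices across the DP group and trim them to the working parameter's size, then all-gather and concatenate the TP shards along the partition dimension found by `search_tp_partition_dim`. Padding removal and device handling are simplified here.

```python
from typing import Optional

import torch
import torch.distributed as dist

def gather_sharded_optim_state(
    v: torch.Tensor,
    working_param: torch.Tensor,
    tp_partition_dim: Optional[int],
    dp_group: dist.ProcessGroup,
    tp_group: dist.ProcessGroup,
    use_zero: bool,
) -> torch.Tensor:
    """Sketch: rebuild a full optimizer-state tensor from ZeRO and TP shards."""
    if use_zero:
        # ZeRO keeps a flat 1-D slice per DP rank; gather, drop tail padding, reshape.
        slices = [torch.empty_like(v) for _ in range(dist.get_world_size(dp_group))]
        dist.all_gather(slices, v, group=dp_group)
        v = torch.cat(slices)[: working_param.numel()].reshape_as(working_param)
    if tp_partition_dim is not None and dist.get_world_size(tp_group) > 1:
        # Tensor parallelism splits the state along one dimension; gather and concat.
        shards = [torch.empty_like(v) for _ in range(dist.get_world_size(tp_group))]
        dist.all_gather(shards, v, group=tp_group)
        v = torch.cat(shards, dim=tp_partition_dim)
    return v  # the real code additionally strips padded-tensor padding and moves to CPU
```
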