diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 30c1257ef..441670a0a 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -17,6 +17,8 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.accelerator import get_accelerator from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( + async_save_state_dict_shards, + create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, load_shard_state_dict, @@ -28,6 +30,7 @@ from colossalai.cluster import DistCoordinator, ProcessGroupMesh from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger from colossalai.shardformer import ShardConfig, ShardFormer +from colossalai.utils.safetensors import load_flat from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero.gemini.memory_tracer import MemStats @@ -82,7 +85,15 @@ class GeminiCheckpointIO(GeneralCheckpointIO): state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for k, v in state_dict.items(): + self.pinned_state_dicts[id(model)][k].copy_(v) + state_dict[k] = self.pinned_state_dicts[id(model)][k] + writer = save(checkpoint, state_dict) + self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors) @@ -106,7 +117,19 @@ class GeminiCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!" state_dict = optimizer.state_dict() if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from colossalai.utils.safetensors import _flatten_optim_state_dict, save + + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[id(optimizer)][k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] + writer = save(checkpoint, flatten_state_dict, metadata) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str): """ @@ -137,17 +160,29 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Path(checkpoint_path).mkdir(parents=True, exist_ok=True) - state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True) + if use_async and self.coordinator.is_master(): + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = model.state_dict_shard( + max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) # Save shards of optimizer states. 
is_master = self.coordinator.is_master() if use_async: - super().save_sharded_model( - model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors, use_async + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=is_master, ) - + self.async_writers.extend(writers) else: total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -158,17 +193,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO): use_safetensors=use_safetensors, ) - # only save the index file on the master rank - if self.coordinator.is_master(): - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model.unwrap(), checkpoint_path) - self.logger.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}.", - ranks=[0], - ) + # only save the index file on the master rank + if self.coordinator.is_master(): + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model.unwrap(), checkpoint_path) + self.logger.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}.", + ranks=[0], + ) def load_sharded_model( self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False @@ -201,7 +236,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO): Path(checkpoint).mkdir(parents=True, exist_ok=True) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) @@ -212,17 +247,36 @@ class GeminiCheckpointIO(GeneralCheckpointIO): torch.save(param_groups, group_file_path) # States are broken into shards within max_shard_size. - state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True) + if use_async and self.coordinator.is_master(): + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + state_dict_shard = optimizer.state_shard( + prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True, pinned_state_dicts=pinned_state_dicts + ) # Save shards of optimizer states. - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) # Wrap up index file. 
Only save it on master rank. if self.coordinator.is_master(): @@ -264,7 +318,10 @@ class GeminiCheckpointIO(GeneralCheckpointIO): # Load optimizer states from shard files under checkpoint path. # For each file, only load the states managed by current process. for shard_file in checkpoint_files: - state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict_shard = load_flat(shard_file) + else: + state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False) optimizer.load_param_states(state_dict_shard) del state_dict_shard gc.collect() diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 79c9379cc..bc9425a0b 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -1488,7 +1488,7 @@ class HybridParallelPlugin(PipelinePluginBase): ) def get_checkpoint_io(self) -> CheckpointIO: - return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage) + return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.sp_group, self.zero_stage) def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]: assert ( diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index 96531a04f..6937b8d74 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -404,7 +404,13 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): def get_checkpoint_io(self) -> MoECheckpointIO: return MoECheckpointIO( - self.dp_group, self.pp_group, self.tp_group, self.ep_group, self.moe_dp_group, self.zero_stage + self.dp_group, + self.pp_group, + self.tp_group, + self.sp_group, + self.ep_group, + self.moe_dp_group, + self.zero_stage, ) def configure( diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index 07be5b051..90d406eef 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -60,7 +60,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO): """ assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!" 
if self.coordinator.is_master(): - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) + super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async) def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): """ diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index b80d6d4b6..1d792757b 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple import torch import torch.nn as nn @@ -26,9 +26,11 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO, utils +from colossalai.checkpoint_io.utils import async_save_state_dict_shards, create_pinned_state_dict from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.logging import get_dist_logger +from colossalai.utils.safetensors import load_flat from .dp_plugin_base import DPPluginBase @@ -49,8 +51,36 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path): assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!" - checkpoint = utils.load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + checkpoint = load_flat(checkpoint, seperator=".") + else: + checkpoint = utils.load_state_dict(checkpoint) + fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=False) + start_index = 0 + id2name = {} + + def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + start_num = len(id2name) + id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name}) + end_num = len(id2name) + start_index += end_num - start_num + + for g in full_optimizer_state["param_groups"]: + get_index_mapping(g) + + new_state = {} + for key, value in checkpoint["state"].items(): + new_state[id2name[int(key)]] = value + checkpoint["state"] = new_state + for g in checkpoint["param_groups"]: + new_group = [] + for param_id in g["params"]: + new_group.append(id2name[param_id]) + g["params"] = new_group + sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model) optimizer.load_state_dict(sharded_osd) @@ -65,7 +95,21 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg): full_model_state = model.state_dict() - utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors) + if self.coordinator.is_master(): + if use_async: + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) + for k, v in full_model_state.items(): + self.pinned_state_dicts[id(model)][k].copy_(v) + full_model_state[k] = self.pinned_state_dicts[id(model)][k] + writer = save(checkpoint, full_model_state) + self.async_writers.append(writer) + else: + 
utils.save_state_dict( + full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors + ) def save_unsharded_optimizer( self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False @@ -75,8 +119,43 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): """ assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!" fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True) - utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) + + if self.coordinator.is_master(): + + # Save order indices instead of Tensors + name2id: Dict[str, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id}) + packed["params"] = [name2id[p] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in full_optimizer_state["param_groups"]] + full_optimizer_state["param_groups"] = param_groups + new_state = {} + for key, value in full_optimizer_state["state"].items(): + new_state[name2id[key]] = value + full_optimizer_state["state"] = new_state + + if use_async: + from colossalai.utils.safetensors import _flatten_optim_state_dict, save + + flatten_state_dict, metadata = _flatten_optim_state_dict(full_optimizer_state, seperator=".") + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[id(optimizer)][k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[id(optimizer)][k] + writer = save(checkpoint, state_dict=flatten_state_dict, metadata=metadata) + self.async_writers.append(writer) + else: + utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False) def save_sharded_model( self, @@ -102,20 +181,38 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): ): state_dict = model.unwrap().state_dict() - state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard) + if use_async and self.coordinator.is_master(): + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = utils.shard_model_checkpoint( + state_dict, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint_path) # In general cases, is_master is set to True to get the right behavior. 
- total_size = utils.save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint_path, - index_file=index_file, - base_filename=weights_name, - is_master=self.coordinator.is_master(), - use_safetensors=use_safetensors, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + ) + self.async_writers.extend(writers) + else: + total_size = utils.save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint_path, + index_file=index_file, + base_filename=weights_name, + is_master=self.coordinator.is_master(), + use_safetensors=use_safetensors, + ) # only save the index file on the master rank if self.coordinator.is_master(): @@ -188,26 +285,66 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): ) if self.coordinator.is_master(): + + # Save order indices instead of Tensors + name2id: Dict[str, int] = {} + start_index = 0 + + def pack_group(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + name2id.update({p: i for i, p in enumerate(group["params"], start_index) if p not in name2id}) + packed["params"] = [name2id[p] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + param_groups = [pack_group(g) for g in fsdp_optim_state["param_groups"]] + fsdp_optim_state["param_groups"] = param_groups + new_state = {} + for key, value in fsdp_optim_state["state"].items(): + new_state[name2id[key]] = value + fsdp_optim_state["state"] = new_state + # Preparing file paths and index file. - states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames( + prefix, use_safetensors=use_async + ) index_file = CheckpointIndexFile(checkpoint) index_file.append_meta_data("param_groups", param_group_file) group_file_path = os.path.join(checkpoint, param_group_file) utils.save_param_groups(fsdp_optim_state, group_file_path) - sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard) - + if use_async: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None + sharded_state = utils.shard_optimizer_checkpoint( + fsdp_optim_state, max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. 
- total_size = utils.save_state_dict_shards( - sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=self.coordinator.is_master(), - use_safetensors=False, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = utils.save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=self.coordinator.is_master(), + use_safetensors=False, + ) index_file.append_meta_data("total_size", total_size) index_file.write_index_file(save_index_file) @@ -239,11 +376,39 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO): fsdp_optim_state = {} checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict_shard = load_flat(shard_file, seperator=".") + else: + state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False) fsdp_optim_state.update(state_dict_shard) fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups) + fsdp_model = optimizer.unwrap_model() + full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model.unwrap(), optim=optimizer, rank0_only=False) + start_index = 0 + id2name = {} + + def get_index_mapping(group: Dict[str, Any]) -> Dict[str, Any]: + nonlocal start_index + start_num = len(id2name) + id2name.update({i: p for i, p in enumerate(group["params"], start_index) if i not in id2name}) + end_num = len(id2name) + start_index += end_num - start_num + + for g in full_optimizer_state["param_groups"]: + get_index_mapping(g) + + new_state = {} + for key, value in fsdp_optim_dict["state"].items(): + new_state[id2name[int(key)]] = value + fsdp_optim_dict["state"] = new_state + for g in fsdp_optim_dict["param_groups"]: + new_group = [] + for param_id in g["params"]: + new_group.append(id2name[param_id]) + g["params"] = new_group + with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT): fsdp_state = FSDP.optim_state_dict_to_load( model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 54da168e5..f6bf1bb4a 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -8,10 +8,12 @@ from typing import Optional import torch.nn as nn from torch.optim import Optimizer +from colossalai.utils.safetensors import load_flat + from .checkpoint_io_base import CheckpointIO from .index_file import CheckpointIndexFile from .utils import ( - async_save_state_dict_shards, + async_move_save_state_dict_shards, create_pinned_state_dict, get_model_base_filenames, get_optimizer_base_filenames, @@ -47,10 +49,6 @@ class GeneralCheckpointIO(CheckpointIO): ): state_dict = model.state_dict() - # TODO(FrankLeeeee): add support for gather_dtensor - if gather_dtensor: - pass - if use_async: from colossalai.utils.safetensors import move_and_save @@ -58,7 +56,6 @@ class GeneralCheckpointIO(CheckpointIO): 
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) self.async_writers.append(writer) - else: # save the checkpoint save_state_dict(state_dict, checkpoint, use_safetensors) @@ -83,7 +80,10 @@ class GeneralCheckpointIO(CheckpointIO): checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames() for shard_file in checkpoint_files: - state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) + if shard_file.endswith(".safetensors"): + state_dict = load_flat(shard_file) + else: + state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False) load_states_into_optimizer(optimizer, state_dict, id_map) sharded_optimizer_loading_epilogue(optimizer) @@ -116,7 +116,7 @@ class GeneralCheckpointIO(CheckpointIO): sharded_state = shard_optimizer_checkpoint(state_dict, max_shard_size=size_per_shard) # Preparing file paths and index file. - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) # Store the information of param groups to param_group_file. @@ -126,14 +126,28 @@ class GeneralCheckpointIO(CheckpointIO): # Save shards of optimizer states. # In general cases, is_master is set to True to get the right behavior. - total_size = save_state_dict_shards( - sharded_state_dict=sharded_state, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=True, - use_safetensors=False, - ) + if use_async: + pinned_state_dict = self.pinned_state_dicts.get(id(optimizer), None) + total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + pinned_state_dict=pinned_state_dict, + state_preprocess=True, + ) + self.pinned_state_dicts[id(optimizer)] = new_pinned_state_dict + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=sharded_state, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=True, + use_safetensors=False, + ) # Wrap up index file. 
index_file.append_meta_data("total_size", total_size) @@ -145,7 +159,10 @@ class GeneralCheckpointIO(CheckpointIO): ) def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path): - checkpoint = load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + checkpoint = load_flat(checkpoint) + else: + checkpoint = load_state_dict(checkpoint) optimizer.load_state_dict(checkpoint) def save_unsharded_optimizer( @@ -156,7 +173,22 @@ class GeneralCheckpointIO(CheckpointIO): use_async: bool = False, ): # TODO(FrankLeeeee): handle distributed tensors - save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False) + state_dict = optimizer.state_dict() + if use_async: + from colossalai.utils.safetensors import _flatten_optim_state_dict, move_and_save + + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = create_pinned_state_dict(flatten_state_dict) + writer = move_and_save( + path=checkpoint, + state_dict=flatten_state_dict, + state_dict_pinned=self.pinned_state_dicts[id(optimizer)], + metadata=metadata, + ) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def save_sharded_model( self, @@ -186,7 +218,7 @@ class GeneralCheckpointIO(CheckpointIO): if use_async: pinned_state_dict = self.pinned_state_dicts.get(id(model), None) - total_size, new_pinned_state_dict, writers = async_save_state_dict_shards( + total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint_path, index_file=index_file, diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index e0701a247..0a2e598ca 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -22,6 +22,7 @@ from colossalai.tensor.padded_tensor import ( to_unpadded_tensor, ) from colossalai.utils import get_current_device, get_non_persistent_buffers_set +from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -69,6 +70,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, + sp_group: ProcessGroup, zero_stage: int, verbose: bool = True, ) -> None: @@ -76,9 +78,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): self.global_dp_group = dp_group self.pp_group = pp_group self.tp_group = tp_group + self.sp_group = sp_group self.dp_rank = dist.get_rank(self.global_dp_group) self.tp_rank = dist.get_rank(self.tp_group) self.pp_rank = dist.get_rank(self.pp_group) + self.sp_rank = dist.get_rank(self.sp_group) self.global_dp_size = dist.get_world_size(dp_group) self.pp_size = dist.get_world_size(pp_group) self.tp_size = dist.get_world_size(tp_group) @@ -88,7 +92,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): @staticmethod def _model_sharder( - model: nn.Module, prefix: str = "", keep_vars: bool = False, size_per_shard: int = 1024 + model: nn.Module, + prefix: str = "", + keep_vars: bool = False, + size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, ) -> Iterator[Tuple[OrderedDict, int]]: # An internel method that breaks state_dict of model into shards within limited size. 
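
# Illustrative sketch (not part of this patch): the pinned-buffer caching pattern these
# checkpoint paths share, reduced to plain torch. Buffers are cached per id(model) /
# id(optimizer) so repeated saves reuse the same pinned CPU memory, and the async writers
# then flush from those pinned views. Assumes a CUDA-capable host (pinned memory needs one);
# the names _pinned_state_dicts / stage_to_pinned are illustrative only.
import torch

_pinned_state_dicts = {}  # id(model_or_optimizer) -> {key: pinned CPU tensor}

def stage_to_pinned(obj, state_dict):
    """Copy each tensor into a cached pinned CPU buffer and return the pinned views."""
    cache = _pinned_state_dicts.setdefault(id(obj), {})
    staged = {}
    for key, tensor in state_dict.items():
        if key not in cache:
            cache[key] = torch.empty_like(tensor, pin_memory=True, device="cpu")
        cache[key].copy_(tensor)  # device-to-host copy into pinned memory
        staged[key] = cache[key]
    return staged
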
@@ -102,6 +110,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if is_padded_tensor(param): param = to_unpadded_tensor(param) param_ = gather_distributed_param(param, keep_vars=False) + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(param_) + param_ = pinned_state_dicts[prefix + name] block, block_size = state_dict_sharder.append_param(prefix + name, param_) if block is not None: yield block, block_size @@ -111,6 +124,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): for name, buf in model.named_buffers(): if buf is not None and name not in non_persist_buffers_set: buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] block, block_size = state_dict_sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -122,6 +140,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): is not torch.nn.Module.get_extra_state ): extra_state = model.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(param_, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size @@ -136,6 +159,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): dp_group: ProcessGroup, tp_group: ProcessGroup, size_per_shard: int = 1024, + pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, ): # An internel method that breaks state_dict of optimizer into shards within limited size. @@ -153,6 +177,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): working_param = param param_id = param_info["param2id"][id(working_param)] + if pinned_state_dicts is not None: + if param_id not in pinned_state_dicts: + pinned_state_dicts[param_id] = {} original_shape = param_info["param2shape"][id(working_param)] state_ = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( state, @@ -162,6 +189,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): tp_group=tp_group, use_zero=use_zero, inplace=False, + pinned_state_dicts=pinned_state_dicts[param_id] if pinned_state_dicts is not None else None, ) block, block_size = state_dict_sharder.append_optim_state(param_id, state_) @@ -216,15 +244,31 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # Then collect the sharded parameters & buffers along tp_group. # Only devices with tp_rank == 0 are responsible for model saving. 
- state_dict_shard = HybridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard) + control_saving = self.tp_rank == 0 and self.sp_rank == 0 + if control_saving and use_async: + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(model)] + else: + pinned_state_dicts = None + state_dict_shard = HybridParallelCheckpointIO._model_sharder( + model, size_per_shard=size_per_shard, pinned_state_dicts=pinned_state_dicts + ) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async) + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + state_preprocess=False, + ) + self.async_writers.extend(writers) else: total_size = save_state_dict_shards( sharded_state_dict=state_dict_shard, @@ -234,16 +278,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): is_master=control_saving, use_safetensors=use_safetensors, ) - if control_saving: - index_file.append_meta_data("total_size", total_size) - index_file.write_index_file(save_index_file) - save_config_file(model, checkpoint) - if self.verbose and self.coordinator.is_master(): - logging.info( - f"The model is split into checkpoint shards. " - f"You can find where each parameters has been saved in the " - f"index located at {save_index_file}." - ) + if control_saving: + index_file.append_meta_data("total_size", total_size) + index_file.write_index_file(save_index_file) + save_config_file(model, checkpoint) + if self.verbose and self.coordinator.is_master(): + logging.info( + f"The model is split into checkpoint shards. " + f"You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) else: # When pipeline is used, each stage produces its own shard files and index files. 
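
# Illustrative sketch (not part of this patch): the shape shared by the new async call
# sites above. The synchronous helper returns only a size; the async helper returns
# (total_size, writers), and the caller keeps the writers on self.async_writers so that
# _sync_d2h() / _sync_io() can flush them later. save_shards and its arguments are an
# illustrative wrapper, not a ColossalAI API.
from colossalai.checkpoint_io import CheckpointIndexFile
from colossalai.checkpoint_io.utils import async_save_state_dict_shards, save_state_dict_shards

def save_shards(ckpt_io, state_dict_shard, checkpoint_dir, weights_name, is_master, use_async):
    index_file = CheckpointIndexFile(checkpoint_dir)
    if use_async:
        total_size, writers = async_save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_dir,
            index_file=index_file,
            base_filename=weights_name,
            is_master=is_master,
        )
        ckpt_io.async_writers.extend(writers)  # drained later by _sync_d2h()/_sync_io()
    else:
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_dir,
            index_file=index_file,
            base_filename=weights_name,
            is_master=is_master,
            use_safetensors=False,
        )
    return total_size, index_file
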
@@ -259,24 +303,25 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) if use_async: - total_size, returned_state_dict, writers = async_save_state_dict_shards( + total_size, writers = async_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint, index_file=index_file, base_filename=weights_name, is_master=control_saving, - use_pp_format=True, - n_write_entries=191, + state_preprocess=False, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=weights_name, + is_master=control_saving, + use_safetensors=use_safetensors, + use_pp_format=True, ) - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=weights_name, - is_master=control_saving, - use_safetensors=use_safetensors, - use_pp_format=True, - ) if control_saving: assert ( @@ -448,26 +493,46 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # Then collect the sharded states along dp_group(if using zero)/tp_group. # Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving. + control_saving = self.dp_rank == 0 and self.tp_rank == 0 and self.sp_rank == 0 + + if use_async and control_saving: + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(optimizer)] = {} + pinned_state_dicts = self.pinned_state_dicts[id(optimizer)] + else: + pinned_state_dicts = None state_dict_shard = HybridParallelCheckpointIO._optimizer_sharder( optimizer, use_zero=self.use_zero, dp_group=self.global_dp_group, tp_group=self.tp_group, size_per_shard=size_per_shard, + pinned_state_dicts=pinned_state_dicts, ) - states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix) + states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.dp_rank == 0 and self.tp_rank == 0 if self.pp_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + ) if control_saving: # Store param groups. @@ -498,18 +563,33 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True) # Manage filenames of sharded weights and index file for each pipeline stage. 
- states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + if not use_async: + states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin") + else: + states_name = states_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors") save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json") save_index_file = os.path.join("tmp_index_files", save_index_file) - total_size = save_state_dict_shards( - sharded_state_dict=state_dict_shard, - checkpoint=checkpoint, - index_file=index_file, - base_filename=states_name, - is_master=control_saving, - use_pp_format=True, - ) + if use_async: + total_size, writers = async_save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + state_preprocess=True, + ) + self.async_writers.extend(writers) + else: + total_size = save_state_dict_shards( + sharded_state_dict=state_dict_shard, + checkpoint=checkpoint, + index_file=index_file, + base_filename=states_name, + is_master=control_saving, + use_pp_format=True, + ) if control_saving: assert ( @@ -622,7 +702,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): continue file_path = os.path.join(ckpt_root_path, filename) - state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) + if file_path.endswith(".safetensors"): + state_dict = load_flat(file_path) + else: + state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False) load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True) loaded_file.add(filename) @@ -672,7 +755,15 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # When pipeline is not used, let master rank directly save the collected state_dict. 
if self.tp_rank == 0: if use_async: - super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async) + from colossalai.utils.safetensors import save + + if id(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + for name, param in state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint, state_dict=state_dict) + self.async_writers.append(writer) else: save_state_dict(state_dict, checkpoint, use_safetensors) else: @@ -686,12 +777,14 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): for _state_dict in state_dict_list: complete_state_dict.update(_state_dict) if use_async: - - from colossalai.utils.safetensors import move_and_save + from colossalai.utils.safetensors import save if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) + self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) + for name, param in complete_state_dict.items(): + self.pinned_state_dicts[id(model)][name].copy_(param) + complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] + writer = save(path=checkpoint, state_dict=complete_state_dict) self.async_writers.append(writer) else: save_state_dict(complete_state_dict, checkpoint, use_safetensors) @@ -757,6 +850,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): # gather complete state from tp shards & dp shards param_id = optimizer.param_info["param2id"][id(working_param)] original_shape = optimizer.param_info["param2shape"][id(working_param)] + local_states[param_id] = HybridParallelCheckpointIO.gather_from_sharded_optimizer_state( state, working_param, @@ -776,7 +870,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): ] state_dict = {"param_groups": param_groups, "state": local_states} if self.coordinator.is_master(): - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from colossalai.utils.safetensors import save + + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[k] + writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) else: # When pipeline is used, first collect state_dict from every pipeline stage, then save the complete state_dict. 
states_list = [None for _ in range(self.pp_size)] @@ -792,7 +898,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_dict = {"param_groups": param_groups, "state": dict()} for _states in states_list: state_dict["state"].update(_states) - save_state_dict(state_dict, checkpoint, use_safetensors=False) + if use_async: + from colossalai.utils.safetensors import save + + flatten_state_dict, metadata = _flatten_optim_state_dict(state_dict) + if id(optimizer) not in self.pinned_state_dicts: + self.pinned_state_dicts = create_pinned_state_dict(flatten_state_dict) + for k, v in flatten_state_dict.items(): + self.pinned_state_dicts[k].copy_(v) + flatten_state_dict[k] = self.pinned_state_dicts[k] + writer = save(path=checkpoint, state_dict=flatten_state_dict, metadata=metadata) + self.async_writers.append(writer) + else: + save_state_dict(state_dict, checkpoint, use_safetensors=False) def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str): """ @@ -818,7 +936,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!" # Complete optimizer state_dict loaded from checkpoint, need to be processed later. - state_dict = load_state_dict(checkpoint) + if checkpoint.endswith(".safetensors"): + state_dict = load_flat(checkpoint) + else: + state_dict = load_state_dict(checkpoint) # Load param_groups. updated_groups = [] @@ -872,6 +993,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): use_zero: bool, inplace: bool, device: torch.device = torch.device("cpu"), + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, ) -> OrderedDict: """ With given parameter and its optimizer states, gather the complete optimizer state for saving. @@ -895,6 +1017,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): state_ = state if inplace else copy.deepcopy(state) for k, v in state_.items(): + if v is None: + continue if isinstance(v, torch.Tensor) and k != "step": # First gather Zero shards. 
if use_zero: @@ -915,7 +1039,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim) v = to_unpadded_tensor(v) - state_[k] = v.detach().clone().to(device) + if pinned_state_dicts is not None: + if k not in pinned_state_dicts: + pinned_state_dicts[k] = torch.empty_like(v, pin_memory=True, device="cpu") + pinned_state_dicts[k].copy_(v) + state_[k] = pinned_state_dicts[k] + else: + state_[k] = v.detach().clone().to(device) return state_ diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 3b07856ca..f6aefd33a 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -44,12 +44,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO): global_dp_group: ProcessGroup, pp_group: ProcessGroup, tp_group: ProcessGroup, + sp_group: ProcessGroup, ep_group: ProcessGroup, moe_dp_group: ProcessGroup, zero_stage: int, verbose: bool = True, ) -> None: - super().__init__(global_dp_group, pp_group, tp_group, zero_stage, verbose) + super().__init__(global_dp_group, pp_group, tp_group, sp_group, zero_stage, verbose) self.global_dp_group = global_dp_group self.global_dp_rank = dist.get_rank(global_dp_group) self.global_dp_size = dist.get_world_size(global_dp_group) @@ -158,7 +159,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO): state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) index_file = CheckpointIndexFile(checkpoint) - control_saving = self.tp_rank == 0 + control_saving = self.tp_rank == 0 and self.sp_rank == 0 if self.pp_size == 1 and self.ep_size == 1: # When pipeline is not used, save the model shards as in general checkpointIO @@ -415,7 +416,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO): # e.g. dp_size = 4, moe_dp_size = 2, ep_size = 2 and use gather # rank 0 saves moe & non-moe params; rank 1 only saves moe params # rank 3 & 4 save nothing - control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 + control_saving = self.tp_rank == 0 and self.moe_dp_rank == 0 and self.sp_rank == 0 if self.pp_size == 1 and self.ep_size == 1: # When pipeline is not used, save the optimizer shards as in general checkpointIO diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index ab599b556..71422f4c2 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -19,6 +19,7 @@ from colossalai.tensor.d_tensor import ( to_global, to_global_for_customized_distributed_tensor, ) +from colossalai.utils.safetensors import _flatten_optim_state_dict SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -266,6 +267,63 @@ def save_state_dict_shards( def async_save_state_dict_shards( + sharded_state_dict: Iterator[Tuple[OrderedDict, int]], + checkpoint: str, + index_file: "CheckpointIndexFile", + base_filename: str, + is_master: bool, + use_pp_format: bool = False, + state_preprocess: bool = False, +) -> Tuple[int, list]: + """ + Save sharded state dict only on master rank, this method can be used by both model and optimizer states. + Args: + sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size. + checkpoint (str): The path of checkpoint directory as string. + index_file (CheckpointIndexFile): The index file object to be updated. 
+ base_filename (str): Decides the prefix of filenames of shards. + is_master (bool): Whether current rank is main process. + use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False. + use_pp_format: (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False. + + Returns: + int: the total size of shards + """ + from colossalai.utils.safetensors import save + + total_size = 0 + shard_filenames = [] + writers = [] + for idx, shard_pair in enumerate(sharded_state_dict): + shard, current_size = shard_pair + # Just loop over the sharder and gather to other ranks if not master + if not is_master: + del shard + continue + shard_file = get_shard_filename(base_filename, idx) + total_size = total_size + current_size + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) + checkpoint_file_path = os.path.join(checkpoint, shard_file) + + if state_preprocess: + state_dict, _ = _flatten_optim_state_dict(state_dict=shard, seperator=".") + else: + state_dict = shard + + # Only save on master rank. + writer = save(checkpoint_file_path, state_dict=state_dict) + writers.append(writer) + shard_filenames.append(shard_file) + del shard + + # Clean folder, deleted unneeded files. + clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format) + + return total_size, writers + + +def async_move_save_state_dict_shards( sharded_state_dict: Iterator[Tuple[OrderedDict, int]], checkpoint: str, index_file: "CheckpointIndexFile", @@ -273,6 +331,7 @@ def async_save_state_dict_shards( is_master: bool, pinned_state_dict: Optional[Dict[str, torch.Tensor]], use_pp_format: bool = False, + state_preprocess: bool = False, ) -> Tuple[int, Dict[str, torch.Tensor], list]: """ Save sharded state dict only on master rank, this method can be used by both model and optimizer states. @@ -309,14 +368,19 @@ def async_save_state_dict_shards( index_file.append_weight_map(key, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file) - if pinned_state_dict is not None: - sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()} + if state_preprocess: + state_dict, _ = _flatten_optim_state_dict(state_dict=shard) else: - sub_pinned_state_dict = create_pinned_state_dict(shard) + state_dict = shard + + if pinned_state_dict is not None: + sub_pinned_state_dict = {k: pinned_state_dict[k] for k in state_dict.keys()} + else: + sub_pinned_state_dict = create_pinned_state_dict(state_dict) returned_state_dict.update(sub_pinned_state_dict) # Only save on master rank. - writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict) + writer = move_and_save(checkpoint_file_path, state_dict, sub_pinned_state_dict) writers.append(writer) shard_filenames.append(shard_file) del shard @@ -327,7 +391,11 @@ def async_save_state_dict_shards( return total_size, returned_state_dict, writers -def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_model_checkpoint( + state_dict: torch.Tensor, + max_shard_size: int = 1024, + pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, +) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. 
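
# Illustrative sketch (not part of this patch): the flat key layout that state_preprocess
# relies on. safetensors only stores a flat {str: tensor} map, so nested optimizer state
# {param_id: {state_key: tensor}} is written under "state.<param_id>.<state_key>" keys
# (what _flatten_optim_state_dict produces and load_flat reverses). A plain-dict round
# trip is shown here for clarity; names and values are toy data.
import torch

nested = {0: {"step": torch.tensor(3.0), "exp_avg": torch.zeros(4)}}

# flatten: one dot-separated key per tensor
flat = {f"state.{pid}.{k}": v for pid, states in nested.items() for k, v in states.items()}
# -> keys "state.0.step" and "state.0.exp_avg"

# unflatten: split the keys to rebuild the nested layout
rebuilt = {}
for key, tensor in flat.items():
    _, pid, state_key = key.split(".", 2)
    rebuilt.setdefault(int(pid), {})[state_key] = tensor
assert rebuilt[0]["exp_avg"].shape == (4,)
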
@@ -336,6 +404,11 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): if not is_distributed_tensor(weight): + if pinned_state_dicts is not None: + if key not in pinned_state_dicts: + pinned_state_dicts[key] = torch.empty_like(weight, pin_memory=True, device="cpu") + pinned_state_dicts[key].copy_(weight) + weight = pinned_state_dicts[key] block, block_size = state_dict_sharder.append_param(key, weight) if block != None: @@ -345,7 +418,9 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) yield state_dict_sharder.current_block, state_dict_sharder.current_block_size -def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: +def shard_optimizer_checkpoint( + state_dict: dict, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None +) -> Iterator[Tuple[OrderedDict, int]]: """ Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a given size. @@ -356,6 +431,15 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> state_dict_sharder = StateDictSharder(max_shard_size) for param_id, state in states.items(): + if pinned_state_dicts is not None: + if param_id not in pinned_state_dicts: + pinned_state_dicts[param_id] = {} + for k, v in state.items(): + if k not in pinned_state_dicts[param_id]: + pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu") + pinned_state_dicts[param_id][k].copy_(v) + state[k] = pinned_state_dicts[param_id][k] + block, block_size = state_dict_sharder.append_optim_state(param_id, state) if block != None: yield block, block_size diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py index d8983436d..8ce6d7335 100644 --- a/colossalai/utils/safetensors.py +++ b/colossalai/utils/safetensors.py @@ -71,6 +71,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d for idx, d in states.items(): for k, v in d.items(): + if v is None: + continue nested_key = f"state{seperator}{idx}{seperator}{k}" if not isinstance(v, torch.Tensor): non_tensor_keys.append(nested_key) @@ -87,7 +89,8 @@ def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[d def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."): state_dict = {} - if metadata is not None: + + if metadata is not None and "non_tensor_keys" in metadata: non_tensor_keys = json.loads(metadata["non_tensor_keys"]) else: non_tensor_keys = [] @@ -128,8 +131,10 @@ def prepare( header = {} offset = 0 + header_metadata = {"format": "pt"} if metadata is not None: - header["__metadata__"] = metadata + header_metadata.update(metadata) + header["__metadata__"] = header_metadata for name, tensor in data.items(): n = tensor.numel() * tensor.element_size() @@ -172,8 +177,9 @@ def move_and_save( path: str, state_dict: Dict[str, torch.Tensor], state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None, + metadata: Optional[Dict[str, str]] = None, ) -> None: - prepared_data, _, tensor_keys = prepare(state_dict) + prepared_data, _, tensor_keys = prepare(state_dict, metadata) n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys)) f_writer.write(n.to_bytes(8, byteorder="little")) @@ 
-188,9 +194,9 @@ def move_and_save( return f_writer -def load_flat(checkpoint_path): +def load_flat(checkpoint_path, seperator: str = "."): with safe_open(checkpoint_path, framework="pt") as f: metadata = f.metadata() state_dict_load = load_file(checkpoint_path) - state_dict = _unflatten_optim_state_dict(state_dict_load, metadata) + state_dict = _unflatten_optim_state_dict(state_dict_load, metadata, seperator) return state_dict diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index a033e917b..9e89e8827 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper): keep_vars: bool = False, max_shard_size: int = 1024, only_rank_0: bool = True, + pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None, ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``. @@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper): gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0)) gathered_param = gathered_param_buffer.pop(param_to_save) + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like( + gathered_param, pin_memory=True, device="cpu" + ) + pinned_state_dicts[prefix + name].copy_(gathered_param) + gathered_param = pinned_state_dicts[prefix + name] block, block_size = sharder.append_param(prefix + name, gathered_param) if block is not None: yield block, block_size @@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper): for name, buf in self.named_buffers(): if buf is not None and name not in self._non_persistent_buffers_set: buffer = buf if keep_vars else buf.detach() + if pinned_state_dicts is not None: + if (prefix + name) not in pinned_state_dicts: + pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu") + pinned_state_dicts[prefix + name].copy_(buffer) + buffer = pinned_state_dicts[prefix + name] block, block_size = sharder.append_param(prefix + name, buffer) if block is not None: yield block, block_size @@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper): is not torch.nn.Module.get_extra_state ): extra_state = self.get_extra_state() + if pinned_state_dicts is not None: + if extra_state_key not in pinned_state_dicts: + pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu") + pinned_state_dicts[extra_state_key].copy_(extra_state) + extra_state = pinned_state_dicts[extra_state_key] block, block_size = sharder.append_param(extra_state_key, extra_state) if block is not None: yield block, block_size diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index ca91b4d9f..def96b19b 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -809,7 +809,11 @@ class GeminiOptimizer(OptimizerWrapper): self.optimizer_loading_epilogue() def state_shard( - self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True + self, + prefix: str = "", + max_shard_size: int = 1024, + only_rank_0: bool = True, + pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None, ) -> Iterator[Tuple[OrderedDict, int]]: """Returns dictionaries containing shards of optimizer states one by one. The max size of each dictionary shard is specified by ``max_shard_size``. 
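
# Illustrative sketch (not part of this patch), using the standard safetensors API rather
# than ColossalAI's own async writer: a safetensors file carries a flat tensor map plus
# string-only header metadata. The patched prepare()/move_and_save() merge caller metadata
# (e.g. the "non_tensor_keys" list emitted by _flatten_optim_state_dict) into that header,
# and load_flat() reads it back through safe_open().metadata(). Keys and values are toy data.
import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

flat_tensors = {"state.0.exp_avg": torch.zeros(4), "state.0.step": torch.tensor(3.0)}
metadata = {"non_tensor_keys": json.dumps([])}  # illustrative payload; values must be str

save_file(flat_tensors, "optim_shard.safetensors", metadata=metadata)

with safe_open("optim_shard.safetensors", framework="pt") as f:
    assert json.loads(f.metadata()["non_tensor_keys"]) == []
    exp_avg = f.get_tensor("state.0.exp_avg")
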
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index a033e917b..9e89e8827 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -903,6 +903,7 @@ class GeminiDDP(ModelWrapper):
         keep_vars: bool = False,
         max_shard_size: int = 1024,
         only_rank_0: bool = True,
+        pinned_state_dicts: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one.
         The max size of dictionary shard is specified by ``max_shard_size``.
@@ -943,6 +944,13 @@ class GeminiDDP(ModelWrapper):
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
                 gathered_param = gathered_param_buffer.pop(param_to_save)

+                if pinned_state_dicts is not None:
+                    if (prefix + name) not in pinned_state_dicts:
+                        pinned_state_dicts[prefix + name] = torch.empty_like(
+                            gathered_param, pin_memory=True, device="cpu"
+                        )
+                    pinned_state_dicts[prefix + name].copy_(gathered_param)
+                    gathered_param = pinned_state_dicts[prefix + name]
                 block, block_size = sharder.append_param(prefix + name, gathered_param)
                 if block is not None:
                     yield block, block_size
@@ -954,6 +962,11 @@ class GeminiDDP(ModelWrapper):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
+                if pinned_state_dicts is not None:
+                    if (prefix + name) not in pinned_state_dicts:
+                        pinned_state_dicts[prefix + name] = torch.empty_like(buffer, pin_memory=True, device="cpu")
+                    pinned_state_dicts[prefix + name].copy_(buffer)
+                    buffer = pinned_state_dicts[prefix + name]
                 block, block_size = sharder.append_param(prefix + name, buffer)
                 if block is not None:
                     yield block, block_size
@@ -964,6 +977,11 @@ class GeminiDDP(ModelWrapper):
             is not torch.nn.Module.get_extra_state
         ):
             extra_state = self.get_extra_state()
+            if pinned_state_dicts is not None:
+                if extra_state_key not in pinned_state_dicts:
+                    pinned_state_dicts[extra_state_key] = torch.empty_like(extra_state, pin_memory=True, device="cpu")
+                pinned_state_dicts[extra_state_key].copy_(extra_state)
+                extra_state = pinned_state_dicts[extra_state_key]
             block, block_size = sharder.append_param(extra_state_key, extra_state)
             if block is not None:
                 yield block, block_size
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index ca91b4d9f..def96b19b 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -809,7 +809,11 @@ class GeminiOptimizer(OptimizerWrapper):
         self.optimizer_loading_epilogue()

     def state_shard(
-        self, prefix: str = "", max_shard_size: int = 1024, only_rank_0: bool = True
+        self,
+        prefix: str = "",
+        max_shard_size: int = 1024,
+        only_rank_0: bool = True,
+        pinned_state_dicts: Optional[Dict[int, Dict[str, torch.Tensor]]] = None,
     ) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing shards of optimizer states one by one.
         The max size of each dictionary shard is specified by ``max_shard_size``.
@@ -829,6 +833,16 @@ class GeminiOptimizer(OptimizerWrapper):

             dist.barrier()
             state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
+            if pinned_state_dicts is not None:
+                if param_id not in pinned_state_dicts:
+                    pinned_state_dicts[param_id] = {}
+                for k, v in state.items():
+                    if v is None:
+                        continue
+                    if k not in pinned_state_dicts[param_id]:
+                        pinned_state_dicts[param_id][k] = torch.empty_like(v, pin_memory=True, device="cpu")
+                    pinned_state_dicts[param_id][k].copy_(v)
+                    state[k] = pinned_state_dicts[param_id][k]
             block, block_size = sharder.append_optim_state(param_id, state)
             if block is not None:
                 yield block, block_size
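The two Gemini producers above now accept a pinned_state_dicts cache, so gathered parameters and collected optimizer states are yielded from pinned CPU memory. A hedged sketch of how a checkpoint writer might drive the new argument (the helper name is made up; `model` is assumed to be a GeminiDDP instance):

from typing import Dict

import torch


def iter_pinned_shards(model, pinned_buffers: Dict[str, torch.Tensor], max_shard_size: int = 1024):
    """Yield model shards whose tensors are already staged in reusable pinned CPU buffers."""
    for shard, shard_size in model.state_dict_shard(
        max_shard_size=max_shard_size, only_rank_0=True, pinned_state_dicts=pinned_buffers
    ):
        # each tensor in `shard` lives in page-locked memory and can be handed
        # straight to an asynchronous safetensors writer
        yield shard, shard_size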
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index 8bee8fe97..a6d65cae5 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -35,7 +35,10 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("use_safetensors", [False, True])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int):
+@parameterize("use_async", [False, True])
+def exam_state_dict_with_origin(
+    placement_config, model_name, use_safetensors: bool, tp_size: int, zero_size: int, use_async: bool
+):
     from transformers import BertForSequenceClassification

     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
@@ -70,7 +73,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
             "",
             (model_size / 3),
             use_safetensors=use_safetensors,
+            use_async=use_async,
         )
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()
         new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
         check_state_dict_equal(bert_model.state_dict(only_rank_0=False), new_bert_model.state_dict())
@@ -83,7 +89,10 @@ def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: b
 @parameterize("size_per_shard", [32])
 @parameterize("tp_size", [1, 2])
 @parameterize("zero_size", [2])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int):
+@parameterize("use_async", [False, True])
+def exam_state_dict(
+    placement_config, shard: bool, model_name: str, size_per_shard: int, tp_size: int, zero_size: int, use_async: bool
+):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
     enable_flash_attention = True if tp_size > 1 else False
@@ -124,14 +133,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(
-            model,
-            model_ckpt_path,
-            shard=shard,
-            size_per_shard=size_per_shard,
-        )
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
+
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+
+        booster.save_optimizer(
+            optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()

         booster.load_model(new_model, model_ckpt_path)

@@ -155,8 +168,18 @@ def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_sha
     loss = criterion(output[output_key])
     booster.backward(loss, new_optimizer)
     new_optimizer.step()
-    booster.save_model(new_model, model_ckpt_path, shard=shard)
-    booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard)
+
+    with shared_tempdir() as new_tempdir:
+        model_ckpt_path = f"{new_tempdir}/model"
+        optimizer_ckpt_path = f"{new_tempdir}/optimizer"
+
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
+        booster.save_model(new_model, model_ckpt_path, shard=shard, use_async=use_async)
+        booster.save_optimizer(new_optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()


 def exam_lazy_from_pretrained():
diff --git a/tests/test_checkpoint_io/test_general_checkpoint_io.py b/tests/test_checkpoint_io/test_general_checkpoint_io.py
index 8431036df..327be0bb7 100644
--- a/tests/test_checkpoint_io/test_general_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_general_checkpoint_io.py
@@ -19,7 +19,8 @@ from colossalai.testing import check_state_dict_equal, clear_cache_before_run, p

 @clear_cache_before_run()
 @parameterize("use_safetensors", [True, False])
-def test_unsharded_checkpoint(use_safetensors: bool):
+@parameterize("use_async", [False, True])
+def test_unsharded_checkpoint(use_safetensors: bool, use_async: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -36,18 +37,21 @@ def test_unsharded_checkpoint(use_safetensors: bool):
     lr_scheduler.step()

     # create a temp file for checkpoint
-    if use_safetensors:
+    if use_async or use_safetensors:
         suffix = ".safetensors"
     else:
         suffix = ".bin"
     model_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
-    optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
+    if use_async:
+        optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile(suffix=suffix)
+    else:
+        optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
     lr_scheduler_ckpt_tempfile = tempfile.NamedTemporaryFile()

     # save the model, optimizer, lr_scheduler
     ckpt_io = GeneralCheckpointIO()
-    ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors)
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
+    ckpt_io.save_model(model, model_ckpt_tempfile.name, use_safetensors=use_safetensors, use_async=use_async)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, use_async=use_async)
     ckpt_io.save_lr_scheduler(lr_scheduler, lr_scheduler_ckpt_tempfile.name)

     # create new model
@@ -55,6 +59,9 @@ def test_unsharded_checkpoint(use_safetensors: bool):
     new_optimizer = Adam(new_model.parameters(), lr=0.001)
     new_lr_scheduler = CosineAnnealingWarmupLR(optimizer, total_steps=10)

+    ckpt_io._sync_d2h()
+    ckpt_io._sync_io()
+
     # load the model, optimizer, lr_scheduler
     ckpt_io.load_model(new_model, model_ckpt_tempfile.name)
     ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
@@ -66,7 +73,8 @@


 @pytest.mark.parametrize("use_safetensors", [True, False])
-def test_sharded_model_checkpoint(use_safetensors: bool):
+@pytest.mark.parametrize("use_async", [False, True])
+def test_sharded_model_checkpoint(use_safetensors: bool, use_async: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -79,21 +87,20 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
     loss.backward()
     optimizer.step()

-    # create a temp file for checkpoint
-    if use_safetensors:
-        pass
-    else:
-        pass
-
     model_ckpt_dir = tempfile.TemporaryDirectory()
     optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()

     # save the model and optimizer
     ckpt_io = GeneralCheckpointIO()
-    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors)
+    ckpt_io.save_model(
+        model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=use_safetensors, use_async=use_async
+    )
     ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name, shard=False)

+    ckpt_io._sync_d2h()
+    ckpt_io._sync_io()
+
     # create new model
     new_model = resnet18()
     new_optimizer = Adam(new_model.parameters(), lr=0.001)

@@ -106,7 +113,8 @@ def test_sharded_model_checkpoint(use_safetensors: bool):
     check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict())


-def test_sharded_optimizer_checkpoint():
+@pytest.mark.parametrize("use_async", [False, True])
+def test_sharded_optimizer_checkpoint(use_async: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(model.parameters(), lr=0.001)
@@ -128,7 +136,10 @@ def test_sharded_optimizer_checkpoint():
     ckpt_io = GeneralCheckpointIO()

     ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
+
+    ckpt_io._sync_d2h()
+    ckpt_io._sync_io()

     # create new model
     new_model = resnet18()
@@ -148,9 +159,16 @@ def test_sharded_optimizer_checkpoint():
     loss.backward()
     new_optimizer.step()

+    # create temp directories for checkpoint
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    optimizer_ckpt_dir = tempfile.TemporaryDirectory()
+
     # save the newly got optimizer
     ckpt_io.save_model(new_model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
-    ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+    ckpt_io.save_optimizer(new_optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
+
+    ckpt_io._sync_d2h()
+    ckpt_io._sync_io()

     # create another new model
     new_new_model = resnet18()
@@ -164,7 +182,8 @@ def test_sharded_optimizer_checkpoint():
     check_state_dict_equal(new_optimizer.state_dict(), new_new_optimizer.state_dict())


-def test_sharded_optimizer_multiple_param_groups():
+@pytest.mark.parametrize("use_async", [False, True])
+def test_sharded_optimizer_multiple_param_groups(use_async: bool):
     # create a model and optimizer
     model = resnet18()
     optimizer = Adam(
@@ -188,7 +207,10 @@ def test_sharded_optimizer_multiple_param_groups():
     ckpt_io = GeneralCheckpointIO()

     ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "", 10, use_safetensors=False)
-    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10)
+    ckpt_io.save_optimizer(optimizer, optimizer_ckpt_dir.name, shard=True, size_per_shard=10, use_async=use_async)
+
+    ckpt_io._sync_d2h()
+    ckpt_io._sync_io()

     # create new model
     new_model = resnet18()
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index 86d7924fb..81d184f76 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -38,12 +38,13 @@ else:
     ]


-@parameterize("shard", [True, False])
+@parameterize("shard", [False, True])
 @parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
+@parameterize("use_async", [False, True])
 @clear_cache_before_run()
-def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
+def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
@@ -85,8 +86,16 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
+
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+        booster.save_optimizer(
+            optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()

         new_model = model_fn().cuda()
diff --git a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
index 87d35f252..b90ea0960 100644
--- a/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_ddp_checkpoint_io.py
@@ -12,14 +12,15 @@ from colossalai.interface import OptimizerWrapper
 from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn


-@parameterize("shard", [True, False])
+@parameterize("shard", [False, True])
 @parameterize("size_per_shard", [16, 128])
-def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
+@parameterize("use_async", [False, True])
+def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int, use_async: bool):
     plugin = TorchDDPPlugin()
     booster = Booster(plugin=plugin)
     model = resnet18()
     criterion = lambda x: x.mean()
-    optimizer = SGD((model.parameters()), lr=0.001)
+    optimizer = SGD((model.parameters()), lr=0.001, momentum=0.5)
     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

     model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
@@ -39,9 +40,18 @@ def check_torch_ddp_checkpointIO(shard: bool, size_per_shard: int):
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
         lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
-        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+
+        if not shard and use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optimizer_ckpt_path = f"{optimizer_ckpt_path}.safetensors"
+
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
+        booster.save_optimizer(
+            optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async
+        )
         booster.save_lr_scheduler(scheduler, lr_scheduler_ckpt_path)
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()

         new_model = resnet18()
diff --git a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
index 12b70cc04..25d901538 100644
--- a/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
@@ -12,7 +12,7 @@ if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin

 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn


 def compare_nested_dict(dict1, dict2):
@@ -43,7 +44,8 @@ def compare_nested_dict(dict1, dict2):
     return True


-def check_torch_fsdp_ckpt():
+@parameterize("use_async", [False, True])
+def check_torch_fsdp_ckpt(use_async: bool):
     model = resnet18()
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
@@ -65,10 +66,17 @@ def check_torch_fsdp_ckpt():
         model_ckpt_path = f"{tempdir}/model"
         optim_ckpt_path = f"{tempdir}/optimizer"

+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+            optim_ckpt_path = f"{optim_ckpt_path}.safetensors"
+
         run_model()

-        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
-        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)
+        booster.save_model(fsdp_model, model_ckpt_path, shard=False, use_async=use_async)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False, use_async=use_async)
+
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()

         full_msd = fsdp_model.state_dict()
         # full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
@@ -106,8 +114,11 @@ def check_torch_fsdp_ckpt():

         run_model()

-        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
-        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
+        booster.save_model(fsdp_model, model_ckpt_path, shard=True, use_async=use_async)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True, use_async=use_async)
+
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()

         full_msd = fsdp_model.unwrap().state_dict()
         full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
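Taken together, the tests above exercise the same calling convention from the Booster side: unsharded async saves target a .safetensors path, and both sync calls must run before the checkpoint is read back. A hedged end-to-end sketch (setup of `booster`, `model`, and `optimizer` is omitted and assumed to use a plugin that supports use_async):

# Sketch of the async-save calling convention exercised by the tests.
model_ckpt_path = "model.safetensors"
optimizer_ckpt_path = "optimizer.safetensors"

booster.save_model(model, model_ckpt_path, shard=False, use_async=True)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=False, use_async=True)

booster.checkpoint_io._sync_d2h()  # wait for device-to-host copies into pinned buffers
booster.checkpoint_io._sync_io()   # wait for the background file writers to finish

booster.load_model(model, model_ckpt_path)
booster.load_optimizer(optimizer, optimizer_ckpt_path)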