diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index d66171c58..394274086 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -9,6 +9,7 @@ from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from tensornvme.async_file_io import AsyncFileWriter
 from torch.distributed import ProcessGroup
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map
@@ -22,6 +23,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
+from colossalai.utils.safetensors import move_and_save
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -199,7 +201,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
             use_async (bool, optional): Whether to save the state_dicts of model asynchronously. Defaults to False.
         """
-
         assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         model._force_wait_all_gather()
         model = model.unwrap()
@@ -224,7 +225,18 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.pp_size == 1:
             # When pipeline is not used, save the model shards as in general checkpointIO
             if use_async:
-                super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors, use_async=use_async)
+                pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+                total_size, pinned_state_dict, writers = async_save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_safetensors=use_safetensors,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                )
+                self.pinned_state_dicts[id(model)] = pinned_state_dict
+                self.async_writers.extend(writers)
             else:
                 total_size = save_state_dict_shards(
                     sharded_state_dict=state_dict_shard,
@@ -234,16 +246,16 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     is_master=control_saving,
                     use_safetensors=use_safetensors,
                 )
-                if control_saving:
-                    index_file.append_meta_data("total_size", total_size)
-                    index_file.write_index_file(save_index_file)
-                    save_config_file(model, checkpoint)
-                    if self.verbose and self.coordinator.is_master():
-                        logging.info(
-                            f"The model is split into checkpoint shards. "
-                            f"You can find where each parameters has been saved in the "
-                            f"index located at {save_index_file}."
-                        )
+            if control_saving:
+                index_file.append_meta_data("total_size", total_size)
+                index_file.write_index_file(save_index_file)
+                save_config_file(model, checkpoint)
+                if self.verbose and self.coordinator.is_master():
+                    logging.info(
+                        f"The model is split into checkpoint shards. "
+                        f"You can find where each parameters has been saved in the "
+                        f"index located at {save_index_file}."
+                    )
 
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
@@ -259,24 +271,28 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
             save_index_file = os.path.join("tmp_index_files", save_index_file)
             if use_async:
-                total_size, returned_state_dict, writers = async_save_state_dict_shards(
+                pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+                total_size, pinned_state_dict, writers = async_save_state_dict_shards(
                     sharded_state_dict=state_dict_shard,
                     checkpoint=checkpoint,
                     index_file=index_file,
                     base_filename=weights_name,
                     is_master=control_saving,
-                    use_pp_format=True,
-                    n_write_entries=191,
+                    use_safetensors=use_safetensors,
+                    n_write_entries=self.N_WRITE_ENTRIES,
+                )
+                self.pinned_state_dicts[id(model)] = pinned_state_dict
+                self.async_writers.extend(writers)
+            else:
+                total_size = save_state_dict_shards(
+                    sharded_state_dict=state_dict_shard,
+                    checkpoint=checkpoint,
+                    index_file=index_file,
+                    base_filename=weights_name,
+                    is_master=control_saving,
+                    use_safetensors=use_safetensors,
+                    use_pp_format=True,
                 )
-            total_size = save_state_dict_shards(
-                sharded_state_dict=state_dict_shard,
-                checkpoint=checkpoint,
-                index_file=index_file,
-                base_filename=weights_name,
-                is_master=control_saving,
-                use_safetensors=use_safetensors,
-                use_pp_format=True,
-            )
 
             if control_saving:
                 assert (
@@ -664,7 +680,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         model = model.unwrap()
         if self.dp_rank != 0:
             return
-
         # The logic of collecting parameter shards along tp degree
         # has been implemented by _save_to_state_dict method of ParallelModule in Shardformer.
         state_dict = model.state_dict()
@@ -686,15 +701,12 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 for _state_dict in state_dict_list:
                     complete_state_dict.update(_state_dict)
 
                 if use_async:
-                    from tensornvme.async_file_io import AsyncFileWriter
-
-                    from colossalai.utils.safetensors import move_and_save
                     writer = AsyncFileWriter(open(checkpoint, "wb"), self.N_WRITE_ENTRIES, backend="pthread")
                     if id(model) not in self.pinned_state_dicts:
-                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                        self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
                     self.async_writers.append(writer)
-                    move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
+                    move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
                 else:
                     save_state_dict(complete_state_dict, checkpoint, use_safetensors)
 
diff --git a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
index 86d7924fb..ea3e2aacd 100644
--- a/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py
@@ -34,6 +34,7 @@ if Version(torch.__version__) < Version("2.0.0"):
 else:
     TEST_CONFIGS = [
         # TODO(ver217): other configs lead to hang
+        {"tp_size": 1, "pp_size": 1, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
         {"tp_size": 1, "pp_size": 2, "num_microbatches": 4, "zero_stage": 1, "precision": "fp16", "initial_scale": 1},
     ]
 
@@ -42,8 +43,9 @@ else:
 @parameterize("model_name", ["transformers_llama_for_causal_lm"])
 @parameterize("size_per_shard", [32])
 @parameterize("test_config", TEST_CONFIGS)
+@parameterize("use_async", [True, False])
 @clear_cache_before_run()
-def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
+def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict, use_async: bool):
     (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) = next(
         iter(model_zoo.get_sub_registry(model_name).values())
     )
@@ -85,8 +87,14 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
-        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        if not use_async:
+            model_ckpt_path = f"{model_ckpt_path}.pt"
+        if use_async:
+            model_ckpt_path = f"{model_ckpt_path}.safetensors"
+        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard, use_async=use_async)
         booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
+        booster.checkpoint_io._sync_d2h()
+        booster.checkpoint_io._sync_io()
         dist.barrier()
 
         new_model = model_fn().cuda()
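
The test changes above encode the calling convention for the new async path: async checkpoints are written to a .safetensors path, and the queued device-to-host copies and file writes must be flushed before the checkpoint is read back. Below is a minimal caller-side sketch of that flow, assuming a Booster built with HybridParallelPlugin and an already-boosted model; the helper name save_model_async and the ckpt_dir argument are illustrative, not part of the ColossalAI API.

import torch.distributed as dist


def save_model_async(booster, model, ckpt_dir: str, shard: bool = False) -> str:
    # Async saves are safetensors-backed in this patch, so the test switches
    # the suffix from .pt to .safetensors when use_async is enabled.
    model_ckpt_path = f"{ckpt_dir}/model.safetensors"
    booster.save_model(model, model_ckpt_path, shard=shard, use_async=True)

    # save_model(use_async=True) only queues device-to-host copies into pinned
    # buffers plus background file writes; flush both stages before any rank
    # tries to load the checkpoint again.
    booster.checkpoint_io._sync_d2h()
    booster.checkpoint_io._sync_io()
    dist.barrier()
    return model_ckpt_path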
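
On the checkpoint-IO side, the diff applies a pinned-buffer reuse pattern: pinned host copies of the state dict are cached in self.pinned_state_dicts keyed by id(model), and the AsyncFileWriter handles are collected in self.async_writers so the later sync calls can wait on them. The snippet below is only a sketch of that caching idea; get_pinned_copy is a hypothetical helper, not the create_pinned_state_dict / async_save_state_dict_shards used in the diff.

from typing import Dict

import torch


def get_pinned_copy(
    cache: Dict[int, Dict[str, torch.Tensor]],
    model_id: int,
    state_dict: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
    # Allocate page-locked CPU buffers once per model; subsequent saves reuse
    # them instead of re-pinning host memory on every checkpoint.
    if model_id not in cache:
        cache[model_id] = {
            name: torch.empty(t.shape, dtype=t.dtype, device="cpu", pin_memory=True)
            for name, t in state_dict.items()
        }
    pinned = cache[model_id]
    # Copies into pinned memory can be issued non_blocking, which is what lets
    # the background file writes overlap with ongoing GPU work.
    for name, t in state_dict.items():
        pinned[name].copy_(t, non_blocking=True)
    return pinned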