diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index a3039f97b..6e91bd8ed 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -137,12 +137,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         state_dict = optimizer.state_dict(pinned_state_dicts, only_on_master=True)
         if self.coordinator.is_master():
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter
 
                 from colossalai.utils.safetensors import save_nested
 
-                f_writer = AsyncFileWriter(checkpoint, n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-                save_nested(f_writer, state_dict)
+                f_writer = save_nested(checkpoint, state_dict)
                 self.async_writers.append(f_writer)
             else:
                 save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -222,16 +220,10 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
                 if use_async:
-                    from tensornvme.async_file_io import AsyncFileWriter
 
                     from colossalai.utils.safetensors import save_nested
 
-                    f_writer = AsyncFileWriter(
-                        checkpoint_file_path,
-                        n_entries=self.N_WRITE_ENTRIES,
-                        backend="pthread",
-                    )
-                    save_nested(f_writer, shard)
+                    f_writer = save_nested(checkpoint_file_path, shard)
                     self.async_writers.append(f_writer)
                 else:
                     save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py
index b96c0c7b8..c67020e97 100644
--- a/colossalai/checkpoint_io/checkpoint_io_base.py
+++ b/colossalai/checkpoint_io/checkpoint_io_base.py
@@ -59,8 +59,6 @@ class CheckpointIO(ABC):
     >>> checkpoint_io.save_optimizer(optimizer, 'optimizer.pt')
     """
 
-    N_WRITE_ENTRIES: int = 32
-
     # ======================================
     # Public methods
     # ======================================
diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py
index 254580677..3bb805131 100644
--- a/colossalai/checkpoint_io/general_checkpoint_io.py
+++ b/colossalai/checkpoint_io/general_checkpoint_io.py
@@ -54,13 +54,11 @@ class GeneralCheckpointIO(CheckpointIO):
             pass
 
         if use_async:
-            from tensornvme.async_file_io import AsyncFileWriter
 
-            writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+            writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
             self.async_writers.append(writer)
-            move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
 
         else:
             # save the checkpoint
@@ -196,7 +194,6 @@ class GeneralCheckpointIO(CheckpointIO):
                 base_filename=weights_name,
                 is_master=True,
                 pinned_state_dict=pinned_state_dict,
-                n_write_entries=self.N_WRITE_ENTRIES,
             )
             self.pinned_state_dicts[id(model)] = new_pinned_state_dict
             self.async_writers.extend(writers)
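Usage sketch (not part of the patch): after the three files above, the AsyncFileWriter is built inside the colossalai.utils.safetensors helpers and returned, so the checkpoint IO classes only collect the writer and drain it later, and the per-class N_WRITE_ENTRIES constant goes away. A minimal sketch of that caller pattern follows; the class, method names, and toy state dict are illustrative, not the real CheckpointIO API.

```python
import tempfile

import torch

from colossalai.utils.safetensors import load_flat, save_nested


class AsyncCheckpointSketch:
    """Illustrative stand-in for the checkpoint IO classes patched above."""

    def __init__(self):
        # The save helpers now hand the writer back; the caller only tracks it.
        self.async_writers = []

    def save_optimizer(self, state_dict, path):
        writer = save_nested(path, state_dict)  # AsyncFileWriter is created inside save_nested
        self.async_writers.append(writer)

    def wait_for_writes(self):
        for writer in self.async_writers:
            writer.sync_before_step()
            writer.synchronize()
        self.async_writers.clear()


if __name__ == "__main__":
    ckpt_io = AsyncCheckpointSketch()
    # Toy optimizer-style state dict; the keys are illustrative.
    state = {
        "state": {0: {"exp_avg": torch.rand(8), "exp_avg_sq": torch.rand(8)}},
        "param_groups": [{"lr": 1e-3, "params": [0]}],
    }
    with tempfile.TemporaryDirectory() as tmpdir:
        path = f"{tmpdir}/optimizer.safetensors"
        ckpt_io.save_optimizer(state, path)
        ckpt_io.wait_for_writes()  # drain the async writes before reading the file back
        restored = load_flat(path)  # load_flat rebuilds the nested optimizer layout
```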
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index da3199e12..e0701a247 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -686,15 +686,13 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for _state_dict in state_dict_list:
                 complete_state_dict.update(_state_dict)
             if use_async:
-                from tensornvme.async_file_io import AsyncFileWriter
 
                 from colossalai.utils.safetensors import move_and_save
 
-                writer = AsyncFileWriter(checkpoint, self.N_WRITE_ENTRIES, backend="pthread")
                 if id(model) not in self.pinned_state_dicts:
                     self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+                writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
                 self.async_writers.append(writer)
-                move_and_save(writer, state_dict, self.pinned_state_dicts[id(model)])
             else:
                 save_state_dict(complete_state_dict, checkpoint, use_safetensors)
diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index 77b9faa0b..5ef0bd354 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -273,7 +273,6 @@ def async_save_state_dict_shards(
     base_filename: str,
     is_master: bool,
     pinned_state_dict: Optional[Dict[str, torch.Tensor]],
-    n_write_entries: int,
     use_pp_format: bool = False,
 ) -> Tuple[int, Dict[str, torch.Tensor], list]:
     """
@@ -290,7 +289,6 @@
     Returns:
         int: the total size of shards
     """
-    from tensornvme.async_file_io import AsyncFileWriter
 
     total_size = 0
     shard_filenames = []
@@ -311,9 +309,6 @@
             index_file.append_weight_map(key, shard_file)
         checkpoint_file_path = os.path.join(checkpoint, shard_file)
 
-        writer = AsyncFileWriter(checkpoint_file_path, n_write_entries, backend="pthread")
-        writers.append(writer)
-
         if pinned_state_dict is not None:
             sub_pinned_state_dict = {k: pinned_state_dict[k] for k in shard.keys()}
         else:
@@ -321,7 +316,8 @@
             returned_state_dict.update(sub_pinned_state_dict)
 
         # Only save on master rank.
-        move_and_save(writer, shard, sub_pinned_state_dict)
+        writer = move_and_save(checkpoint_file_path, shard, sub_pinned_state_dict)
+        writers.append(writer)
         shard_filenames.append(shard_file)
         del shard
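Usage sketch (not part of the patch): async_save_state_dict_shards and the hybrid-parallel path above now get the writer back from move_and_save instead of constructing one per shard. The sketch below mirrors the updated test further down; the tensor names are illustrative, and pin_memory() assumes an accelerator-enabled torch build.

```python
import tempfile

import torch
from safetensors.torch import load_file

from colossalai.utils import get_current_device
from colossalai.utils.safetensors import move_and_save

model_state = {"weight": torch.rand(64, 64), "bias": torch.rand(64)}
device_state = {k: v.to(get_current_device()) for k, v in model_state.items()}
pinned = {k: v.pin_memory() for k, v in model_state.items()}  # host-side staging buffers

with tempfile.TemporaryDirectory() as tmpdir:
    path = f"{tmpdir}/model.safetensors"
    # move_and_save() writes each tensor via write_tensor(), staging it through the
    # matching pinned buffer, and returns the AsyncFileWriter it created.
    writer = move_and_save(path, device_state, pinned)
    writer.sync_before_step()
    writer.synchronize()
    del writer
    assert torch.equal(load_file(path)["bias"], model_state["bias"])
```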
diff --git a/colossalai/utils/safetensors.py b/colossalai/utils/safetensors.py
index 8b8cb627f..d8983436d 100644
--- a/colossalai/utils/safetensors.py
+++ b/colossalai/utils/safetensors.py
@@ -15,6 +15,8 @@ import io
 
 from torch.distributed.distributed_c10d import _pickler, _unpickler
 
+ASYNC_WRITE_ENTRIES = 32
+
 
 def _object_to_tensor(obj, device):
     f = io.BytesIO()
@@ -149,32 +151,31 @@ def prepare(
     return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys
 
 
-def save(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
+def save(path: str, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None) -> None:
     prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
-
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensors))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
 
     for tensor in tensors:
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)
+    return f_writer
 
 
-def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+def save_nested(path: str, state_dict: Dict[str, torch.Tensor]) -> None:
     flatten_data, metadata = _flatten_optim_state_dict(state_dict)
-    save(f_writer, flatten_data, metadata)
+    return save(path, flatten_data, metadata)
 
 
 def move_and_save(
-    f_writer: AsyncFileWriter,
+    path: str,
     state_dict: Dict[str, torch.Tensor],
     state_dict_pinned: Optional[Dict[str, torch.Tensor]] = None,
 ) -> None:
     prepared_data, _, tensor_keys = prepare(state_dict)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset
-
+    f_writer = AsyncFileWriter(path, n_entries=ASYNC_WRITE_ENTRIES, backend="pthread", n_tasks=2 + len(tensor_keys))
     f_writer.write(n.to_bytes(8, byteorder="little"))
     f_writer.write(header_bytes)
@@ -184,6 +185,7 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+    return f_writer
 
 
 def load_flat(checkpoint_path):
diff --git a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
index 452080a49..eb17d4c10 100644
--- a/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
+++ b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py
@@ -83,7 +83,11 @@ class TensorBucket:
         unflat_buffers = list(map(list, zip(*unflat_buffers)))
         for unflat_shards, tensor in zip(unflat_buffers, self._bucket):
             write_back_tensor = self._write_back_pairs[tensor]
-            write_back_tensor.data.copy_(
-                _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()].reshape_as(write_back_tensor)
-            )
+            rec_tensor = _flatten_dense_tensors(unflat_shards)[: write_back_tensor.numel()]
+            if write_back_tensor.is_contiguous():
+                rec_tensor = rec_tensor.view_as(write_back_tensor)
+            else:
+                rec_tensor = rec_tensor.reshape_as(write_back_tensor)
+            write_back_tensor.data.copy_(rec_tensor)
+
         self.empty()
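Usage sketch (not part of the patch): save() differs from move_and_save() in that it writes the prepared tensors directly via write_raw (no pinned staging) and sizes the writer it returns for the header plus one task per tensor (n_tasks=2 + len(tensors) above). A sketch of the plain save() round trip follows; the state-dict keys are illustrative.

```python
import tempfile

import torch
from safetensors.torch import load_file

from colossalai.utils.safetensors import save

state_dict = {"module.weight": torch.rand(32, 32), "module.bias": torch.rand(32)}

with tempfile.TemporaryDirectory() as tmpdir:
    path = f"{tmpdir}/model.safetensors"
    writer = save(path, state_dict)  # n_entries / n_tasks are chosen inside save()
    writer.sync_before_step()
    writer.synchronize()  # the file is only complete once the writer is drained
    del writer
    loaded = load_file(path)
    assert torch.equal(loaded["module.bias"], state_dict["module.bias"])
```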
diff --git a/tests/test_checkpoint_io/test_safetensors_async_io.py b/tests/test_checkpoint_io/test_safetensors_async_io.py
index 882b5e2c5..7de73b46b 100644
--- a/tests/test_checkpoint_io/test_safetensors_async_io.py
+++ b/tests/test_checkpoint_io/test_safetensors_async_io.py
@@ -3,18 +3,12 @@ import tempfile
 import torch
 from safetensors.torch import load_file
 
+from colossalai.testing import check_state_dict_equal, clear_cache_before_run
+from colossalai.utils import get_current_device
 from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested
 
-try:
-    from tensornvme.async_file_io import AsyncFileWriter
-except ModuleNotFoundError:
-    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
-
-
-from colossalai.testing import check_state_dict_equal
-from colossalai.utils import get_current_device
-
 
+@clear_cache_before_run()
 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
@@ -111,8 +105,7 @@
         }
 
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
-        f_writer = AsyncFileWriter(optimizer_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict)
+        f_writer = save_nested(optimizer_saved_path, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -120,8 +113,7 @@
         check_state_dict_equal(load_state_dict, optimizer_state_dict)
 
         optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
-        f_writer = AsyncFileWriter(optimizer_shard_saved_path, n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer = save_nested(optimizer_shard_saved_path, optimizer_state_dict["state"])
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -134,8 +126,7 @@
             "module.weight2": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        save(f_writer, model_state_dict)
+        f_writer = save(model_saved_path, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
@@ -145,8 +136,7 @@
         model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
         model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
         model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
-        f_writer = AsyncFileWriter(model_saved_path, n_entries=191, backend="pthread")
-        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer = move_and_save(model_saved_path, model_state_dict_cuda, model_state_pinned)
         f_writer.sync_before_step()
         f_writer.synchronize()
         del f_writer
diff --git a/tests/test_optimizer/test_dist_lamb.py b/tests/test_optimizer/test_dist_lamb.py
index 66e8e49c7..615c5c33c 100644
--- a/tests/test_optimizer/test_dist_lamb.py
+++ b/tests/test_optimizer/test_dist_lamb.py
@@ -10,7 +10,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import DistributedLamb, Lamb
 from colossalai.tensor.d_tensor import get_shard_dim_1d, is_distributed_tensor
 from colossalai.tensor.d_tensor.api import clear_layout_converter
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 from tests.kit.model_zoo import model_zoo
@@ -108,6 +108,7 @@ def set_dist_grad(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(1, 4), (4, 1), (2, 2)])
+@clear_cache_before_run()
 def run_dist_lamb_basic(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None:
@@ -177,6 +178,7 @@ def run_dist_lamb_basic(
 @parameterize("p_g_dtype", _ALLOWED_P_G_TYPES)
 @parameterize("bias_correction", [False, True])
 @parameterize("tp_zero_size", [(2, 2), (4, 1), (1, 4)])
+@clear_cache_before_run()
 def run_dist_lamb_fwd_bwd(
     bias_correction: bool, p_g_dtype: tuple[torch.dtype, torch.dtype], tp_zero_size: tuple[int, int]
 ) -> None:
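Usage sketch (not part of the patch): the test changes also stack @clear_cache_before_run() under the @parameterize decorators, presumably so cached memory does not pile up across parameterized runs. The placeholder below only shows the decorator order; a real test body would build models and optimizers.

```python
import torch

from colossalai.testing import clear_cache_before_run, parameterize


@parameterize("dtype", [torch.float32, torch.float16])
@clear_cache_before_run()  # innermost, as in the tests above
def run_toy_check(dtype: torch.dtype) -> None:
    # Placeholder body standing in for a real distributed-optimizer check.
    assert torch.ones(4, dtype=dtype).sum().item() == 4


if __name__ == "__main__":
    run_toy_check()  # @parameterize drives one call per dtype value
```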