From b60be18dcca5c88263bd704fb256c36dd5729904 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Sat, 27 Jan 2024 16:06:33 +0800
Subject: [PATCH] [moe] fix mixtral checkpoint io (#5314)

---
 .../colossal_moe/models/mixtral_checkpoint.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
index 635eebd89..629ad7349 100644
--- a/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
+++ b/applications/ColossalMoE/colossal_moe/models/mixtral_checkpoint.py
@@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
 
         if self.real_dp_rank != 0:
+            dist.barrier()
             return
 
         # ep_rank 0 saves all the parameters and buffers.
@@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                     f"index located at {save_index_file}."
                 )
 
+            dist.barrier()
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
             # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
             else:
+                dist.barrier()
                 return
 
-            dist.barrier(self.pp_group)
-            dist.barrier(self.ep_group)
+            dist.barrier()
 
             # The global master rank integrates the index files and clean the folder.
             if self.coordinator.is_master():
@@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         # Devices along the same dp_group share the same copies of states when zero is not used.
         # In this case only let the device with dp_rank == 0 save the model.
         if not self.use_zero and self.real_dp_rank != 0:
+            dist.barrier()
             return
 
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
@@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                     f"index located at {save_index_file}."
                 )
 
+            dist.barrier()
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
             # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
             else:
+                dist.barrier()
                 return
 
-            dist.barrier(self.pp_group)
-            dist.barrier(self.ep_group)
+            dist.barrier()
 
             # The global master rank integrates the index files and clean the folder.
             if self.coordinator.is_master():
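The patch follows the usual rule for torch.distributed collectives: every rank must execute the same sequence of barriers. Ranks that return early (non-zero real_dp_rank, or the non-saving ranks in the pipeline branch) now call dist.barrier() before returning, and the two group-scoped barriers are collapsed into a single global barrier. Below is a minimal, self-contained sketch of that pattern; the function name save_model_sketch, the single-process gloo setup, and the file layout are illustrative assumptions, not code from the PR.

import os

import torch
import torch.distributed as dist


def save_model_sketch(state_dict: dict, checkpoint_dir: str, dp_rank: int) -> None:
    """Barrier-on-every-path pattern adopted by the patch (sketch only)."""
    if dp_rank != 0:
        # Ranks that skip saving must still join the collective; otherwise
        # the saving rank blocks forever on the barrier at the end.
        dist.barrier()
        return

    # dp_rank == 0 writes the checkpoint (the real code splits the state
    # dict into many shards plus an index file).
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(state_dict, os.path.join(checkpoint_dir, "model_shard.pt"))

    # One global barrier instead of dist.barrier(self.pp_group) followed by
    # dist.barrier(self.ep_group): every rank reaches the same collective
    # regardless of which branch it took above.
    dist.barrier()


if __name__ == "__main__":
    # Single-process group only so the sketch runs as-is; a real run would
    # be launched with torchrun across multiple ranks.
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    save_model_sketch({"weight": torch.zeros(2, 2)}, "/tmp/ckpt_sketch", dp_rank=dist.get_rank())
    dist.destroy_process_group()

Launched across several ranks, dropping either barrier call from the sketch reproduces the kind of hang this patch addresses: the saving rank waits on a collective that the early-returning ranks never enter.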