Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 20:54:55 +00:00)
[moe] fix mixtral checkpoint io (#5314)
commit b60be18dcc
parent da39d21b71
@@ -135,6 +135,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         Path(checkpoint).mkdir(parents=True, exist_ok=True)
 
         if self.real_dp_rank != 0:
+            dist.barrier()
             return
 
         # ep_rank 0 saves all the parameters and buffers.
@@ -171,6 +172,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                         f"index located at {save_index_file}."
                     )
 
+            dist.barrier()
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
             # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@@ -201,10 +203,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
             else:
+                dist.barrier()
                 return
 
-            dist.barrier(self.pp_group)
-            dist.barrier(self.ep_group)
+            dist.barrier()
 
             # The global master rank integrates the index files and clean the folder.
             if self.coordinator.is_master():
@@ -360,6 +362,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
         # Devices along the same dp_group share the same copies of states when zero is not used.
         # In this case only let the device with dp_rank == 0 save the model.
         if not self.use_zero and self.real_dp_rank != 0:
+            dist.barrier()
             return
 
         # Then collect the sharded states along dp_group(if using zero)/tp_group.
@@ -401,6 +404,7 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                         f"index located at {save_index_file}."
                     )
 
+            dist.barrier()
         else:
             # When pipeline is used, each stage produces its own shard files and index files.
             # Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
@@ -428,10 +432,10 @@ class MixtralMoEHybridParallelCheckpointIO(HybridParallelCheckpointIO):
                 index_file.append_meta_data("total_size", total_size)
                 index_file.write_index_file(save_index_file)
             else:
+                dist.barrier()
                 return
 
-            dist.barrier(self.pp_group)
-            dist.barrier(self.ep_group)
+            dist.barrier()
 
             # The global master rank integrates the index files and clean the folder.
             if self.coordinator.is_master():
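Every hunk applies the same synchronization pattern: a rank that skips the actual file writes still enters a barrier before returning early, and the two group-scoped barriers on pp_group and ep_group collapse into a single global dist.barrier(). Below is a minimal standalone sketch of that pattern; the function name, the dp_rank argument, and the comments are hypothetical illustrations rather than ColossalAI code, and the sketch assumes torch.distributed has already been initialized.

import torch.distributed as dist

def save_like_the_fixed_checkpoint_io(dp_rank: int) -> None:
    # Hypothetical sketch of the barrier placement this commit introduces.
    if dp_rank != 0:
        # Non-writing ranks must still join the collective; before the fix
        # they returned early and the writing ranks blocked on the barrier.
        dist.barrier()
        return

    # ... rank 0 of each data-parallel group writes its shards and index file ...

    # One global barrier (matching the early-return branch above) replaces
    # the former pair of group-scoped barriers on pp_group and ep_group.
    dist.barrier()

Because both branches call dist.barrier() on the default (world) process group exactly once, every rank participates in the same collective and no subset of ranks is left waiting.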