mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[moe] implement tp
This commit is contained in:
@@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
|
||||
|
||||
# ep_rank 0 saves all the parameters and buffers.
|
||||
# other ep_ranks save only experts
|
||||
ep_param_pattern = "experts." if self.ep_rank != 0 else None
|
||||
|
||||
# Then collect the sharded parameters & buffers along tp_group.
|
||||
# Only devices with tp_rank == 0 are responsible for model saving.
|
||||
state_dict_shard = MoECheckpointIO._model_sharder(
|
||||
model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
|
||||
)
|
||||
state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
|
||||
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
|
||||
index_file = CheckpointIndexFile(checkpoint)
|
||||
control_saving = self.tp_rank == 0
|
||||
|
Reference in New Issue
Block a user