[moe] implement tp

Author: botbw
Date: 2024-07-16 06:03:57 +00:00
Committed by: Hongxin Liu
Parent: 0b5bbe9ce4
Commit: dc583aa576
8 changed files with 79 additions and 40 deletions


@@ -151,13 +151,10 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
        # ep_rank 0 saves all the parameters and buffers.
        # other ep_ranks save only experts
        ep_param_pattern = "experts." if self.ep_rank != 0 else None
        # Then collect the sharded parameters & buffers along tp_group.
        # Only devices with tp_rank == 0 are responsible for model saving.
        state_dict_shard = MoECheckpointIO._model_sharder(
            model, size_per_shard=size_per_shard, param_name_pattern=ep_param_pattern
        )
        state_dict_shard = MoECheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint)
        control_saving = self.tp_rank == 0
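
The hunk above concerns the saving policy this commit touches: non-zero ep_ranks keep only their expert parameters, and only ranks with tp_rank == 0 control writing of the shards. The following is a rough, standalone sketch of that policy only, not ColossalAI's implementation; `save_moe_shard` and its shard naming are hypothetical.

```python
import os

import torch


def save_moe_shard(model: torch.nn.Module, checkpoint: str, ep_rank: int, tp_rank: int) -> None:
    # ep_rank 0 keeps all parameters and buffers; other ep_ranks keep only
    # their expert parameters, mirroring the "experts." filter in the diff.
    state_dict = model.state_dict()
    if ep_rank != 0:
        state_dict = {k: v for k, v in state_dict.items() if "experts." in k}

    # Only devices with tp_rank == 0 are responsible for model saving,
    # mirroring `control_saving = self.tp_rank == 0`.
    if tp_rank != 0:
        return

    os.makedirs(checkpoint, exist_ok=True)
    # Hypothetical shard naming: one file per expert-parallel rank.
    torch.save(state_dict, os.path.join(checkpoint, f"model-ep{ep_rank}.bin"))
```

In the real MoECheckpointIO the filtering and size-based splitting appear to be handled inside `_model_sharder`, and the tp_rank gate is applied via `control_saving` when the shards and the index file are written out.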