mirror of https://github.com/hpcaitech/ColossalAI.git
[moe] init mixtral impl
@@ -181,6 +181,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_communication: bool = True,
         use_ep_inside: bool = True,
         custom_policy: Policy = None,
+        checkpoint_io: Optional[MoECheckpintIO] = None,
     ) -> None:
         assert (
             dist.get_world_size() % (tp_size * pp_size) == 0
@@ -200,6 +201,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.checkpoint_io = checkpoint_io
         # we change pg mesh to (pp, dp, tp) for better moe performance
         self.pg_mesh = ProcessGroupMesh(self.pp_size, self.dp_size, self.tp_size)
 
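The (pp, dp, tp) ordering noted in the comment above puts tensor parallelism on the innermost mesh axis, so the ranks of each tp group are consecutive. Below is a minimal, framework-free sketch of the rank layout that this ordering implies; the axis sizes and the coord_of helper are made up for illustration and are not part of ColossalAI's API.

# Illustrative only: flat ranks mapped onto a (pp, dp, tp) mesh with the
# last axis varying fastest. Axis sizes here are hypothetical.
pp_size, dp_size, tp_size = 2, 2, 2  # 8 ranks in total

def coord_of(rank):
    # Decompose a flat rank into (pp, dp, tp) coordinates.
    tp = rank % tp_size
    dp = (rank // tp_size) % dp_size
    pp = rank // (tp_size * dp_size)
    return pp, dp, tp

for rank in range(pp_size * dp_size * tp_size):
    print(rank, coord_of(rank))
# Ranks 0-1, 2-3, ... differ only in the tp coordinate, so each tp group
# lands on adjacent ranks (typically the same node).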
@@ -323,7 +325,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )
 
     def get_checkpoint_io(self) -> MoECheckpintIO:
-        self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        if self.checkpoint_io is None:
+            self.checkpoint_io = MoECheckpintIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
+        else:
+            self.checkpoint_io = self.checkpoint_io(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
         return self.checkpoint_io
 
     def configure(
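The new branch in get_checkpoint_io() keeps the old default (build a MoECheckpintIO from the plugin's process groups) when no checkpoint_io was passed to the constructor; otherwise the supplied value is called with the same arguments, i.e. it is treated as a class or factory with a MoECheckpintIO-compatible signature. A minimal, self-contained sketch of that dispatch, with stand-in classes (DefaultIO, CustomIO) instead of the real ColossalAI types:

# Stand-ins for illustration; the real default is MoECheckpintIO, and the
# user-supplied value is expected to accept the same constructor arguments.
class DefaultIO:
    def __init__(self, dp_group, pp_group, tp_group, zero_stage):
        self.args = (dp_group, pp_group, tp_group, zero_stage)

class CustomIO(DefaultIO):
    pass

def get_checkpoint_io(checkpoint_io, dp_group, pp_group, tp_group, zero_stage):
    # Mirrors the branch added in this commit: build the default when no
    # checkpoint IO was supplied, otherwise call the supplied class/factory.
    if checkpoint_io is None:
        return DefaultIO(dp_group, pp_group, tp_group, zero_stage)
    return checkpoint_io(dp_group, pp_group, tp_group, zero_stage)

io = get_checkpoint_io(CustomIO, "dp", "pp", "tp", zero_stage=1)
assert isinstance(io, CustomIO)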