From 3f09a6145f7c68350a508296c36fa77814e6010d Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Fri, 16 Aug 2024 10:12:50 +0800
Subject: [PATCH] [fp8] add use_fp8 option for MoeHybridParallelPlugin (#6009)

---
 colossalai/booster/plugin/moe_hybrid_parallel_plugin.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index ca3a68373..374fc6535 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -215,6 +215,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         if overlap_communication or zero_stage == 2:
             overlap_communication = False
@@ -324,7 +325,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.sp_group = self.pg_mesh.get_group_along_axis(self.tp_axis)
         else:
             self.sp_group = self.pg_mesh.get_group_along_axis(self.sp_axis)
-
+        self.use_fp8 = use_fp8
         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
             sequence_parallel_process_group=self.sp_group,
@@ -428,6 +429,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             use_ddp=use_ddp,
             ddp_config=self.ddp_config,
             custom_policy=self.custom_policy,
+            use_fp8=self.use_fp8,
         )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.ep_size > 1:
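
Note (not part of the patch): below is a minimal usage sketch showing where the new use_fp8 keyword would be passed. Only fp8_communication and the new use_fp8 keyword come from the diff above; the parallel sizes, zero stage, precision, and launch call are illustrative assumptions about the surrounding training setup, not taken from this patch.

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import MoeHybridParallelPlugin

    # Assumes the script is launched with torchrun; sets up the distributed environment.
    colossalai.launch_from_torch()

    # Illustrative parallel configuration; only the last two flags relate to this patch.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        ep_size=2,
        zero_stage=1,
        precision="bf16",
        fp8_communication=True,  # pre-existing flag, visible in the diff context
        use_fp8=True,            # flag added by this patch, forwarded to the sharded model
    )
    booster = Booster(plugin=plugin)
    # model, optimizer, dataloader, etc. would then be wrapped via booster.boost(...)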