From 3b5c314bea0c7947fd91e26731a427b2e536b8d4 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Fri, 1 Nov 2024 03:54:08 +0000
Subject: [PATCH] [fix] fix fp8 args in HybridParallel

---
 colossalai/booster/plugin/hybrid_parallel_plugin.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index 752e8e1e8..1af20f473 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -722,8 +722,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             overlap_allgather=overlap_allgather,
             fp8_communication=fp8_communication,
             backward_context=model._hook_context,
-            fp8_communication=fp8_communication,
-            backward_context=model._hook_context,
         )
 
     def sync_dp_grads(self):
@@ -1162,7 +1160,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                     enable_metadata_cache=enable_metadata_cache,
                     overlap_p2p=overlap_p2p,
                     fp8_communication=fp8_communication,
-                    fp8_communication=fp8_communication,
                 )
             elif pp_style == "1f1b":
                 self.scheduler = OneForwardOneBackwardSchedule(
@@ -1213,7 +1210,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             make_vocab_size_divisible_by=make_vocab_size_divisible_by,
             gradient_checkpoint_config=gradient_checkpoint_config,
             fp8_communication=fp8_communication,
-            fp8_communication=fp8_communication,
             inner_ring_size=inner_ring_size,
             pg_mesh=self.pg_mesh,
             sp_axis=self.sp_axis,
@@ -1247,7 +1243,6 @@ class HybridParallelPlugin(PipelinePluginBase):
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
             fp8_communication=fp8_communication,
-            fp8_communication=fp8_communication,
        )
 
         self.max_norm = max_norm
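
For context on why every deleted line above is a duplicated keyword argument: Python rejects a call that repeats a keyword argument at parse time, so the file would fail to import until the duplicates are removed. Below is a minimal sketch of that behavior; the function name build_optimizer is hypothetical and not part of ColossalAI.

# Hypothetical illustration, not ColossalAI code: repeating a keyword argument
# in a call is a compile-time error in Python, which is why the duplicated
# fp8_communication / backward_context arguments are deleted in the patch.
def build_optimizer(fp8_communication=False, backward_context=None):
    return {"fp8_communication": fp8_communication, "backward_context": backward_context}

# The duplicated form never runs at all:
#   build_optimizer(fp8_communication=True, fp8_communication=True)
#   SyntaxError: keyword argument repeated: fp8_communication

# With each keyword passed exactly once, the call is valid:
print(build_optimizer(fp8_communication=True, backward_context=None))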