From cc500b3e25dc8d626829e0098a1cc54d6438f93b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 8 Oct 2024 09:34:09 +0000 Subject: [PATCH] [fix] fix mixtral policy; --- colossalai/shardformer/policies/mixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index c570badd6..af5b15ed5 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -269,7 +269,7 @@ class MixtralPolicy(Policy): for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( - stage_manager.is_last_stage(ignore_chunk=True) + not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True) ): # for zbv, when is_first_stage (last fwd), we append norm # for interleaved, when is_last_stage (last fwd), we also append norm