From 292a504bea0ca7af22d2f21c3826ca0a4ea7b4ab Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Tue, 8 Oct 2024 09:25:11 +0000 Subject: [PATCH] [fix] fix mixtral policy; --- colossalai/shardformer/policies/mixtral.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 3a41b2799..c570badd6 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -268,9 +268,11 @@ class MixtralPolicy(Policy): held_layers.append(module.embed_tokens) for start_idx, end_idx in stage_indices: held_layers.extend(module.layers[start_idx:end_idx]) - if stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True): - held_layers.append(module.norm) - elif stage_manager.is_last_stage(ignore_chunk=True): + if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or ( + stage_manager.is_last_stage(ignore_chunk=True) + ): + # for zbv, when is_first_stage (last fwd), we append norm + # for interleaved, when is_last_stage (last fwd), we also append norm held_layers.append(module.norm) else: layers_per_stage = stage_manager.distribute_layers(len(module.layers))