diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py index 85895820e..10df143c9 100644 --- a/colossalai/shardformer/policies/mixtral.py +++ b/colossalai/shardformer/policies/mixtral.py @@ -61,7 +61,7 @@ class MixtralPolicy(Policy): policy[attn_cls] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) - if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism: + if self.shard_config.enable_sequence_parallelism: if self.pipeline_stage_manager is not None: # NOTE: we are replacing model forward for both sequence parallelism and pipeline parallelism # if both are enabled, one of them will be ignored