diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 0a0db434e..dcb832639 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1037,6 +1037,9 @@ def get_jit_fused_bert_output_forward():
     return forward
 
 
+# Fix the tgt_len size in sequence parallel attention:
+# same as the one in BertSdpaSelfAttention.forward in transformers v4.51.3, except for
+# _, _, tgt_len, _ = query_layer.shape
 def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
     from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
 
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 9f30622ee..fd4e020b0 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -80,6 +80,7 @@ class BertPolicy(Policy):
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
         if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in the BERT sequence parallel attention forward.
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
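
For context, the new comment states that the sequence-parallel forward matches BertSdpaSelfAttention.forward from transformers v4.51.3 except that tgt_len is read from query_layer.shape rather than from the incoming hidden states. Below is a minimal, self-contained sketch of why that distinction matters, assuming the local hidden states are sharded along the sequence dimension while the gathered query_layer spans the full sequence; sp_size and all tensor shapes are illustrative assumptions, not the ColossalAI implementation.

import torch
import torch.nn.functional as F

bsz, full_seq_len, num_heads, head_dim = 2, 8, 4, 16
sp_size = 2  # hypothetical sequence-parallel group size

# Each rank only holds a slice of the sequence before attention (assumption for illustration).
local_hidden_states = torch.randn(bsz, full_seq_len // sp_size, num_heads * head_dim)

# After the sequence is gathered for attention, q/k/v cover the full sequence length.
query_layer = torch.randn(bsz, num_heads, full_seq_len, head_dim)
key_layer = torch.randn(bsz, num_heads, full_seq_len, head_dim)
value_layer = torch.randn(bsz, num_heads, full_seq_len, head_dim)

# Reading tgt_len from local_hidden_states would give 4 (the local shard length),
# and the final reshape of the SDPA output would then fail.
_, _, tgt_len, _ = query_layer.shape  # the line the diff comment refers to

attn_output = F.scaled_dot_product_attention(query_layer, key_layer, value_layer)
attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 8, 64])

The sketch only demonstrates the shape bookkeeping; the actual patched forward in colossalai/shardformer/modeling/bert.py keeps the rest of the upstream SDPA attention logic unchanged.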