diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index dcb832639..b788cbf58 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1039,7 +1039,6 @@ def get_jit_fused_bert_output_forward():
 
 # Fix the tgt_len size in sequence parallel attention:
 # same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
-# _, _, tgt_len, _ = query_layer.shape
 def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
     from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
 
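
For context, a minimal sketch (not part of the patch) of the tgt_len issue the removed comment line refers to: when the sequence dimension is sharded across ranks, reading the target length from query_layer.shape yields only the local shard length. The names sp_size, local_seq_len, and the scaling by sp_size below are illustrative assumptions for this sketch, not code from bert.py.

# Illustrative sketch only; assumes query_layer's sequence dimension holds
# just the local shard when sequence parallelism is enabled.
import torch

batch, num_heads, head_dim = 2, 12, 64
full_seq_len, sp_size = 128, 4                # assumed: sequence split across 4 ranks
local_seq_len = full_seq_len // sp_size       # each rank holds 32 tokens

# Upstream BertSdpaSelfAttention (transformers v4.51.3) reads the target length as
#     _, _, tgt_len, _ = query_layer.shape
# which is correct when query_layer spans the whole sequence.
query_layer = torch.randn(batch, num_heads, local_seq_len, head_dim)
_, _, tgt_len, _ = query_layer.shape          # 32: only the local shard

# Under sequence parallelism the attention output is later laid out against the
# full sequence, so a tgt_len taken from the sharded tensor would be too small.
full_tgt_len = tgt_len * sp_size              # 128 (assumed recovery of the full length)

print(tgt_len, full_tgt_len)                  # 32 128

Per the comment kept in the hunk, the sequence-parallel forward otherwise mirrors the BertSdpaSelfAttention forward from transformers v4.51.3; only the way this length is obtained differs.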