Update bert.py

flybird11111 2025-05-27 10:57:06 +08:00 committed by GitHub
parent 17654cb6cb
commit 611c1247ba


@@ -1039,7 +1039,6 @@ def get_jit_fused_bert_output_forward():
# Fix the tgt_len size in sequence parallel attention:
# same as the BertSdpaSelfAttention forward in transformers v4.51.3, except for the
# _, _, tgt_len, _ = query_layer.shape
def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
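The hunk above shows only the start of the new helper; the body of the patched forward is not part of this view. As a rough illustration of the tgt_len issue the comment refers to (an illustrative toy with made-up shapes, not the actual patch): under sequence parallelism the sequence dimension seen by query_layer can be a per-rank shard, so a reshape that hard-codes tgt_len taken from query_layer.shape can disagree with the surrounding hidden states. Deriving the length from the tensor being reshaped (or using -1) avoids that.

import torch
import torch.nn.functional as F

# Hypothetical sizes: full sequence length 8, sharded over 2 ranks -> local length 4.
batch, heads, head_dim = 2, 2, 16
full_seq_len, local_seq_len = 8, 4

# Toy tensors standing in for one rank's view inside sequence-parallel attention.
query_layer = torch.randn(batch, heads, local_seq_len, head_dim)
key_layer = torch.randn(batch, heads, full_seq_len, head_dim)
value_layer = torch.randn(batch, heads, full_seq_len, head_dim)

# The line the comment points at reads the length from the (possibly sharded) query:
_, _, tgt_len, _ = query_layer.shape  # 4 here, i.e. the local shard, not 8

attn_output = F.scaled_dot_product_attention(query_layer, key_layer, value_layer)

# Reshaping with -1 keeps the result consistent with however the sequence
# dimension was partitioned on this rank, instead of pinning it to tgt_len.
attn_output = attn_output.transpose(1, 2).reshape(batch, -1, heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 4, 32])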