add explantion

This commit is contained in:
wangbluo 2025-05-14 10:15:25 +08:00
parent 2237531137
commit d665d6740a
2 changed files with 4 additions and 0 deletions

View File

@ -1037,6 +1037,9 @@ def get_jit_fused_bert_output_forward():
return forward
# Fix the tgt_len size in sequence parallel attention:
# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
# _, _, tgt_len, _ = query_layer.shape
def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
from transformers.models.bert.modeling_bert import BertSdpaSelfAttention

View File

@ -80,6 +80,7 @@ class BertPolicy(Policy):
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_sequence_parallelism:
# Fix the tgt_len size in bert sequence parallel attention forward.
self.append_or_create_method_replacement(
description={
"forward": get_bert_sequence_parallel_attention_forward(self.shard_config),