Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-18 07:57:46 +00:00)
add explanation
commit d665d6740a
parent 2237531137
@@ -1037,6 +1037,9 @@ def get_jit_fused_bert_output_forward():
     return forward


+# Fix the tgt_len size in sequence parallel attention:
+# same as the BertSdpaSelfAttention forward in transformers v4.51.3, except for
+# _, _, tgt_len, _ = query_layer.shape
 def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
     from transformers.models.bert.modeling_bert import BertSdpaSelfAttention

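For context, the three comment lines added above point at one difference from the upstream code: the stock BertSdpaSelfAttention forward derives tgt_len from hidden_states, whereas under sequence parallelism the sequence length actually flowing through attention is the one carried by query_layer, so the replacement forward presumably reads it from there. A minimal, hedged sketch of just that reshape step (names follow the Hugging Face forward; the rest of the patched forward is omitted and assumed):

import torch


def reshape_attn_output(attn_output: torch.Tensor, query_layer: torch.Tensor) -> torch.Tensor:
    # Sketch only, not the patched forward. query_layer has shape
    # (bsz, num_heads, tgt_len, head_dim); taking tgt_len from it rather than
    # from hidden_states keeps the final reshape consistent with the
    # sequence-parallel layout, which is what the diff comment refers to.
    bsz, _, tgt_len, _ = query_layer.shape
    # attn_output: (bsz, num_heads, tgt_len, head_dim) -> (bsz, tgt_len, num_heads * head_dim)
    return attn_output.transpose(1, 2).reshape(bsz, tgt_len, -1)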
@@ -80,6 +80,7 @@ class BertPolicy(Policy):
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

         if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in bert sequence parallel attention forward.
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
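The policy hunk only adds a comment, but it documents the mechanism the surrounding lines use: append_or_create_method_replacement registers the factory-produced forward so that shardformer binds it onto the target attention module in place of the stock one. A hedged, self-contained sketch of what such a forward replacement amounts to in general (ToyAttention and get_toy_parallel_forward are illustrative stand-ins, not ColossalAI code):

from types import MethodType

import torch
import torch.nn as nn


def get_toy_parallel_forward(scale: float):
    # Factory pattern mirroring get_bert_sequence_parallel_attention_forward:
    # close over configuration and return the replacement forward.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x) * scale  # stand-in for the sequence-parallel attention math

    return forward


class ToyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


module = ToyAttention()
# Bind the replacement forward onto the instance, shadowing the class-level forward.
module.forward = MethodType(get_toy_parallel_forward(scale=0.5), module)
out = module(torch.randn(2, 4))  # now dispatches to the replacement forward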