Merge pull request #6305 from wangbluo/update_bert

update_bert
This commit is contained in:
Hanks 2025-05-14 10:19:34 +08:00 committed by GitHub
commit c28b3c39db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 1 deletions

View File

@ -58,7 +58,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
# TODO(jianghai): add explaination of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@ -1037,6 +1037,90 @@ def get_jit_fused_bert_output_forward():
return forward
# Fix the tgt_len size in sequence parallel attention:
# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
# _, _, tgt_len, _ = query_layer.shape
def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
def forward(
self: BertSdpaSelfAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
bsz, tgt_len, _ = hidden_states.size()
query_layer = self.transpose_for_scores(self.query(hidden_states))
# If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
# mask needs to be such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None
current_states = encoder_hidden_states if is_cross_attention else hidden_states
attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
# Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
key_layer, value_layer = past_key_value
else:
key_layer = self.transpose_for_scores(self.key(current_states))
value_layer = self.transpose_for_scores(self.value(current_states))
if past_key_value is not None and not is_cross_attention:
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
# Further calls to cross_attention layer can then reuse all cross-attention
# key/value_states (first "if" case)
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
# all previous decoder key/value_states. Further calls to uni-directional self-attention
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)
# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
# attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
# Reference: https://github.com/pytorch/pytorch/issues/112577
if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
query_layer = query_layer.contiguous()
key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
# a causal mask in case tgt_len == 1.
is_causal = (
True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=attention_mask,
dropout_p=self.dropout_prob if self.training else 0.0,
is_causal=is_causal,
)
attn_output = attn_output.transpose(1, 2)
_, _, tgt_len, _ = query_layer.shape
attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
outputs = (attn_output,)
if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs
return forward
def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,

View File

@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_sequence_parallel_attention_forward,
get_jit_fused_bert_intermediate_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@ -48,6 +49,7 @@ class BertPolicy(Policy):
BertLayer,
BertModel,
BertOutput,
BertSdpaSelfAttention,
BertSelfOutput,
)
@ -77,6 +79,16 @@ class BertPolicy(Policy):
use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
if self.shard_config.enable_sequence_parallelism:
# Fix the tgt_len size in bert sequence parallel attention forward.
self.append_or_create_method_replacement(
description={
"forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
},
policy=policy,
target_key=BertSdpaSelfAttention,
)
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0