Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-21 21:22:04 +00:00)
Commit c28b3c39db
@@ -58,7 +58,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    ):
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         # TODO(jianghai): add explanation of the output here.
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,90 @@ def get_jit_fused_bert_output_forward():
     return forward


+# Fix the tgt_len size in sequence parallel attention:
+# identical to BertSdpaSelfAttention.forward in transformers v4.51.3, except that tgt_len is
+# recomputed from query_layer.shape before reshaping the attention output:
+#     _, _, tgt_len, _ = query_layer.shape
+def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
+    from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
+
+    def forward(
+        self: BertSdpaSelfAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        _, _, tgt_len, _ = query_layer.shape
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    return forward
+
+
 def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
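Editor's note on why the tgt_len fix is needed: under sequence parallelism the hidden_states entering the attention module presumably hold only this rank's slice of the sequence, while the q/k/v projections operate on the gathered full sequence, so hidden_states.size(1) and query_layer.shape[2] no longer agree and the attention output must be reshaped with the latter. The snippet below is a minimal standalone sketch of that shape mismatch using toy tensors; the split/gather is faked with torch.cat, and the names (sp_size, full_seq, etc.) are illustrative only, not ColossalAI APIs.

# Minimal sketch (toy tensors, no ColossalAI code): each rank sees seq_len / sp_size tokens in
# hidden_states, but the projection works on the gathered sequence, so tgt_len must be read from
# query_layer rather than from hidden_states.
import torch

bsz, full_seq, sp_size, num_heads, head_dim = 2, 8, 2, 4, 16
hidden_size = num_heads * head_dim

# this rank's sequence shard, as it would enter the attention module
hidden_states = torch.randn(bsz, full_seq // sp_size, hidden_size)

# stand-in for "all-gather along the sequence dim, then apply the query projection"
gathered = torch.cat([hidden_states, hidden_states], dim=1)  # (bsz, full_seq, hidden_size)
query_layer = gathered.view(bsz, full_seq, num_heads, head_dim).transpose(1, 2)

local_tgt_len = hidden_states.size(1)  # 4 -> wrong length for reshaping the attention output
true_tgt_len = query_layer.shape[2]    # 8 -> matches the gathered sequence actually attended over
print(local_tgt_len, true_tgt_len)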
@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
+    get_bert_sequence_parallel_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -48,6 +49,7 @@ class BertPolicy(Policy):
             BertLayer,
             BertModel,
             BertOutput,
+            BertSdpaSelfAttention,
             BertSelfOutput,
         )

@@ -77,6 +79,16 @@ class BertPolicy(Policy):

         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv

+        if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in bert sequence parallel attention forward.
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=BertSdpaSelfAttention,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
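For context, the hunk above registers the patched forward so ShardFormer can swap it in for every BertSdpaSelfAttention module when sequence parallelism is enabled. The exact wiring lives in the Policy base class; the snippet below is only a generic sketch of what a "forward" method replacement amounts to at runtime (not ColossalAI's implementation, and nn.Linear merely stands in for the target module).

# Generic sketch: rebinding a free function as a module instance's forward.
from types import MethodType

import torch
import torch.nn as nn


def patched_forward(self: nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # hypothetical replacement forward: same math as nn.Linear, tagged for demonstration
    self.was_patched = True
    return nn.functional.linear(x, self.weight, self.bias)


layer = nn.Linear(4, 4)                             # stands in for the target module
layer.forward = MethodType(patched_forward, layer)  # roughly what description/target_key resolve to
_ = layer(torch.randn(2, 4))
assert layer.was_patched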