mirror of https://github.com/hpcaitech/ColossalAI.git
commit c28b3c39db
colossalai/shardformer/modeling/bert.py

@@ -58,7 +58,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,  # this is from the previous stage
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    ):
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         # TODO(jianghai): add explaination of the output here.
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,90 @@ def get_jit_fused_bert_output_forward():
     return forward
 
 
+# Fix the tgt_len size in sequence parallel attention:
+# same with the one in BertSdpaSelfAttention forward in v4.51.3 transformers except the
+# _, _, tgt_len, _ = query_layer.shape
+def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
+    from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
+
+    def forward(
+        self: BertSdpaSelfAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        _, _, tgt_len, _ = query_layer.shape
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    return forward
+
+
 def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
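Side note (not part of the commit): a minimal shape sketch of why the added `_, _, tgt_len, _ = query_layer.shape` line matters. Under split-gather style sequence parallelism the hidden states entering the attention module hold only this rank's slice of the sequence, while the q/k/v projections can produce the full gathered length, so a `tgt_len` taken from `hidden_states.size()` would no longer match `attn_output`. The tensor names and sizes below are illustrative assumptions, not values taken from ColossalAI.

import torch

bsz, num_heads, head_dim = 2, 12, 64
sp_size, full_len = 4, 128
local_len = full_len // sp_size  # sequence slice held by this rank

# hidden_states arrives sequence-split; the projected query is assumed to carry the gathered full length
hidden_states = torch.randn(bsz, local_len, num_heads * head_dim)
query_layer = torch.randn(bsz, num_heads, full_len, head_dim)
attn_output = torch.randn(bsz, num_heads, full_len, head_dim).transpose(1, 2)  # (bsz, full_len, heads, dim)

# Using tgt_len from hidden_states (local_len == 32) would make the reshape below fail;
# re-deriving it from query_layer gives the gathered length (128) that attn_output actually has.
_, _, tgt_len, _ = query_layer.shape
attn_output = attn_output.reshape(bsz, tgt_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([2, 128, 768])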
colossalai/shardformer/policies/bert.py

@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
+    get_bert_sequence_parallel_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -48,6 +49,7 @@ class BertPolicy(Policy):
             BertLayer,
             BertModel,
             BertOutput,
+            BertSdpaSelfAttention,
             BertSelfOutput,
         )
 
@@ -77,6 +79,16 @@ class BertPolicy(Policy):
 
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
+        if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in bert sequence parallel attention forward.
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=BertSdpaSelfAttention,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
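For orientation, a rough usage sketch of how the new policy branch would be reached. Only enable_sequence_parallelism and enable_tensor_parallelism appear in the diff itself; the remaining ShardConfig fields (tensor_parallel_process_group, sequence_parallelism_mode="split_gather") and the ShardFormer.optimize call are recalled from the colossalai.shardformer API and may differ between versions, so treat this as an assumption-laden sketch rather than a verified recipe.

import torch.distributed as dist
from transformers import BertConfig, BertModel
from colossalai.shardformer import ShardConfig, ShardFormer

# assumes the distributed process group has already been initialized (e.g. via colossalai.launch)
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
    enable_sequence_parallelism=True,  # takes the new BertSdpaSelfAttention forward-replacement branch
    sequence_parallelism_mode="split_gather",  # assumed mode name; check your ColossalAI version
)
model = BertModel(BertConfig())
sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(model)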