diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 580f3618c..bdd7e2f8a 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -3,6 +3,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_attention_mask_for_sdpa,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -59,12 +63,11 @@ class BertPipelineForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
-        # TODO(jianghai): add explaination of the output here.
         r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
 
@@ -130,13 +133,43 @@ class BertPipelineForwards:
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
-        if attention_mask is None:
-            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+        use_sdpa_attention_masks = (
+            self.attn_implementation == "sdpa"
+            and self.position_embedding_type == "absolute"
+            and head_mask is None
+            and not output_attentions
+        )
+
+        # Expand the attention mask
+        if use_sdpa_attention_masks and attention_mask.dim() == 2:
+            # Expand the attention mask for SDPA.
+            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+            if self.config.is_decoder:
+                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    input_shape,
+                    embedding_output,
+                    past_key_values_length,
+                )
+            else:
+                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+        else:
+            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+            # ourselves in which case we just need to make it broadcastable to all heads.
+            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
 
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-        # ourselves in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
-        attention_mask = extended_attention_mask
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder and encoder_hidden_states is not None:
@@ -144,7 +177,14 @@ class BertPipelineForwards:
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
+                # Expand the attention mask for SDPA.
+                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
+                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
+                )
+            else:
+                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
 
@@ -211,30 +252,25 @@ class BertPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.encoder.gradient_checkpointing and self.encoder.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
+                    attention_mask=extended_attention_mask,
+                    head_mask=layer_head_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask,
-                    layer_head_mask,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    past_key_value,
-                    output_attentions,
+                    attention_mask=extended_attention_mask,
+                    head_mask=layer_head_mask,
+                    encoder_hidden_states=encoder_hidden_states,
+                    encoder_attention_mask=encoder_attention_mask,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
                 )
             hidden_states = layer_outputs[0]
             if use_cache: