diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index bdd7e2f8a..580f3618c 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -3,10 +3,6 @@ from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask_for_sdpa,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -63,11 +59,12 @@ class BertPipelineForwards:
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
     ):
+        # TODO(jianghai): add explaination of the output here.
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
             the model is configured as a decoder.
-        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
+        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
 
@@ -133,43 +130,13 @@ class BertPipelineForwards:
         # past_key_values_length
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
-        embedding_output = self.embeddings(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            token_type_ids=token_type_ids,
-            inputs_embeds=inputs_embeds,
-            past_key_values_length=past_key_values_length,
-        )
-
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
-
-        use_sdpa_attention_masks = (
-            self.attn_implementation == "sdpa"
-            and self.position_embedding_type == "absolute"
-            and head_mask is None
-            and not output_attentions
-        )
-
-        # Expand the attention mask
-        if use_sdpa_attention_masks and attention_mask.dim() == 2:
-            # Expand the attention mask for SDPA.
-            # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
-            if self.config.is_decoder:
-                extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
-                    attention_mask,
-                    input_shape,
-                    embedding_output,
-                    past_key_values_length,
-                )
-            else:
-                extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-        else:
-            # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
-            # ourselves in which case we just need to make it broadcastable to all heads.
-            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
+            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        attention_mask = extended_attention_mask
         # If a 2D or 3D attention mask is provided for the cross-attention
         # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if self.config.is_decoder and encoder_hidden_states is not None:
@@ -177,14 +144,7 @@ class BertPipelineForwards:
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
                 encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
-            if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
-                # Expand the attention mask for SDPA.
-                # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
-                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
-                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
-                )
-            else:
-                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
 
@@ -204,8 +164,7 @@ class BertPipelineForwards:
                 inputs_embeds=inputs_embeds,
                 past_key_values_length=past_key_values_length,
             )
-            print("hidden_states:", hidden_states.shape)
-            print("bert_model_forward hidden_states:", hidden_states.shape)
+
         # inherit from bert_layer,this should be changed when we add the feature to record hidden_states
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -252,25 +211,30 @@ class BertPipelineForwards:
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
             if self.encoder.gradient_checkpointing and self.encoder.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    encoder_layer.__call__,
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(encoder_layer),
                     hidden_states,
-                    attention_mask=extended_attention_mask,
-                    head_mask=layer_head_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
                 )
             else:
                 layer_outputs = encoder_layer(
                     hidden_states,
-                    attention_mask=extended_attention_mask,
-                    head_mask=layer_head_mask,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_attention_mask=encoder_attention_mask,
-                    past_key_value=past_key_value,
-                    output_attentions=output_attentions,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             hidden_states = layer_outputs[0]
             if use_cache:
@@ -1176,7 +1140,7 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
             inputs_embeds=inputs_embeds,
             past_key_values_length=past_key_values_length,
         )
-        print("embedding_output:", embedding_output.shape)
+
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         embedding_output = split_forward_gather_backward(
@@ -1185,7 +1149,6 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
             embedding_output,
             dim=1,
             process_group=shard_config.tensor_parallel_process_group,
             fp8_communication=shard_config.fp8_communication,
         )
-        print("after split_forward_gather_backward embedding_output:", embedding_output.shape)
         if encoder_hidden_states is not None:
             encoder_hidden_states = split_forward_gather_backward(
                 encoder_hidden_states,