diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 8181a68a0..27461be04 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,16 +1,9 @@ -import math -import warnings from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -110,19 +103,18 @@ def get_tp_falcon_decoder_layer_forward(): alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -138,7 +130,8 @@ def get_tp_falcon_decoder_layer_forward(): head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - **kwargs, + cache_position=cache_position, + position_embeddings=position_embeddings, ) attention_output = attn_outputs[0] @@ -152,6 +145,13 @@ def get_tp_falcon_decoder_layer_forward(): ) mlp_layernorm_out = self.post_attention_layernorm(residual) + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): + mlp_layernorm_out = attention_layernorm_out + outputs = attn_outputs[1:] # MLP. @@ -167,7 +167,7 @@ def get_tp_falcon_decoder_layer_forward(): else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, present, attentions + return outputs # hidden_states, past_kv, attentions return forward @@ -190,6 +190,7 @@ class FalconPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -206,9 +207,8 @@ class FalconPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if past_key_values is not None: - logger.warning_once("past_key_values is not supported for pipeline models at the moment.") - past_key_values = None + logger.warning_once("past_key_values is not supported for pipeline models at the moment.") + past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -221,7 +221,7 @@ class FalconPipelineForwards: elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = inputs_embeds @@ -229,12 +229,9 @@ class FalconPipelineForwards: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -243,10 +240,10 @@ class FalconPipelineForwards: all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation + alibi = None past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-2] + batch_size, seq_length, _ = hidden_states.shape if self.use_alibi: mask = ( torch.ones( @@ -256,73 +253,30 @@ class FalconPipelineForwards: else attention_mask ) alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) - else: - alibi = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - if alibi is None: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - elif head_mask is None: - alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - - attention_mask_2d = attention_mask - # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # We take care to integrate alibi bias in the attention_mask here. - if attention_mask_2d is None: - attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) - else: - min_dtype = torch.finfo(alibi.dtype).min - attention_mask = torch.masked_fill( - alibi / math.sqrt(self.config.hidden_size // self.num_heads), - attention_mask < -1, - min_dtype, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1 and attention_mask.device.type == "cuda": - attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) - else: - # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi + ) + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate( - zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx - ): + for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -331,28 +285,32 @@ class FalconPipelineForwards: block.__call__, hidden_states, alibi, - attention_mask, + causal_mask, position_ids, head_mask[i], - layer_past, + past_key_values, use_cache, output_attentions, + cache_position, + position_embeddings, ) else: outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, + layer_past=past_key_values, + attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = outputs[0] if use_cache is True: - presents = presents + (outputs[1],) + outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -365,6 +323,7 @@ class FalconPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): + if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 68a548aee..362f33176 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -246,6 +246,7 @@ class FalconPolicy(Policy): module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.h))