From 5d167f2148c7ebcbff3eafe7464e3212550172ea Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 28 Apr 2025 18:01:53 +0800 Subject: [PATCH 1/8] fix --- colossalai/shardformer/modeling/falcon.py | 132 ++++++++-------------- colossalai/shardformer/policies/falcon.py | 1 + 2 files changed, 51 insertions(+), 82 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 8181a68a0..c2802063f 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -11,6 +11,7 @@ from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -110,19 +111,18 @@ def get_tp_falcon_decoder_layer_forward(): alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) + residual = hidden_states - if self.config.new_decoder_architecture: + if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -138,7 +138,8 @@ def get_tp_falcon_decoder_layer_forward(): head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - **kwargs, + cache_position=cache_position, + position_embeddings=position_embeddings, ) attention_output = attn_outputs[0] @@ -152,6 +153,13 @@ def get_tp_falcon_decoder_layer_forward(): ) mlp_layernorm_out = self.post_attention_layernorm(residual) + if ( + self.config.new_decoder_architecture + and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1 + ): + mlp_layernorm_out = attention_layernorm_out + outputs = attn_outputs[1:] # MLP. 
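Note on the hunk above: it makes the Falcon decoder layer honour `config.num_ln_in_parallel_attn`, reusing the attention layernorm output for the MLP when only one layernorm is configured. Below is a minimal, self-contained sketch of just that routing; the `pick_layernorm_outputs` helper and the `SimpleNamespace` config are illustrative stand-ins, not code from the patch.

# Illustrative sketch of the layernorm routing encoded by the hunk above;
# `cfg` is a stand-in config object, not transformers' FalconConfig.
from types import SimpleNamespace

import torch
import torch.nn as nn

def pick_layernorm_outputs(cfg, hidden_states, ln_attn, ln_mlp, input_layernorm):
    if cfg.new_decoder_architecture and cfg.num_ln_in_parallel_attn == 2:
        attn_ln_out = ln_attn(hidden_states)   # separate LN feeding attention
        mlp_ln_out = ln_mlp(hidden_states)     # separate LN feeding the MLP
    else:
        attn_ln_out = input_layernorm(hidden_states)
        mlp_ln_out = None                      # filled in later (e.g. from the residual)
    if cfg.new_decoder_architecture and cfg.parallel_attn and cfg.num_ln_in_parallel_attn == 1:
        mlp_ln_out = attn_ln_out               # single-LN variant: MLP reuses the attention LN output
    return attn_ln_out, mlp_ln_out

cfg = SimpleNamespace(new_decoder_architecture=True, parallel_attn=True, num_ln_in_parallel_attn=1)
hidden = torch.randn(2, 4, 8)
attn_out, mlp_out = pick_layernorm_outputs(cfg, hidden, nn.LayerNorm(8), nn.LayerNorm(8), nn.LayerNorm(8))
assert mlp_out is attn_out  # both paths share one layernorm output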
@@ -167,7 +175,7 @@ def get_tp_falcon_decoder_layer_forward(): else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, present, attentions + return outputs # hidden_states, past_kv, attentions return forward @@ -190,6 +198,7 @@ class FalconPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -206,9 +215,9 @@ class FalconPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - if past_key_values is not None: - logger.warning_once("past_key_values is not supported for pipeline models at the moment.") - past_key_values = None + + logger.warning_once("past_key_values is not supported for pipeline models at the moment.") + past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -221,7 +230,7 @@ class FalconPipelineForwards: elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = inputs_embeds @@ -229,12 +238,9 @@ class FalconPipelineForwards: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - if self.gradient_checkpointing and self.training: if use_cache: - logger.warning( + logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -243,10 +249,10 @@ class FalconPipelineForwards: all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation + alibi = None past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[-2] - + + batch_size, seq_length, _ = hidden_states.shape if self.use_alibi: mask = ( torch.ones( @@ -256,73 +262,30 @@ class FalconPipelineForwards: else attention_mask ) alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) - else: - alibi = None - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device - ) - position_ids = position_ids.unsqueeze(0) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. 
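Both the Falcon and Bloom forwards in this series build an ALiBi bias via `build_alibi_tensor` and fold it into the attention mask. For reference, a simplified sketch of that bias, assuming `num_heads` is a power of two (the real helper also handles the general case; `simple_alibi` is an illustrative name, not an API of either library):

# Simplified ALiBi bias, assuming num_heads is a power of two.
import torch

def simple_alibi(attention_mask: torch.Tensor, num_heads: int, dtype=torch.float32) -> torch.Tensor:
    # per-head slopes form a geometric sequence: 2^(-8/n), 2^(-16/n), ...
    ratio = 2.0 ** (-8.0 / num_heads)
    slopes = torch.tensor([ratio ** (i + 1) for i in range(num_heads)])
    # position index of every key, counted over non-padding tokens only
    positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask       # [batch, kv_len]
    alibi = slopes[None, :, None] * positions[:, None, :]                  # [batch, heads, kv_len]
    return alibi.reshape(attention_mask.shape[0] * num_heads, 1, -1).to(dtype)

mask = torch.ones(2, 5)                        # no padding
print(simple_alibi(mask, num_heads=4).shape)   # torch.Size([8, 1, 5])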
- if alibi is None: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - elif head_mask is None: - alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - - attention_mask_2d = attention_mask - # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - - # We take care to integrate alibi bias in the attention_mask here. - if attention_mask_2d is None: - attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) - else: - min_dtype = torch.finfo(alibi.dtype).min - attention_mask = torch.masked_fill( - alibi / math.sqrt(self.config.hidden_size // self.num_heads), - attention_mask < -1, - min_dtype, - ) - - # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend - # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 - if seq_length > 1 and attention_mask.device.type == "cuda": - attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) - else: - # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi + ) + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate( - zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx - ): + for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -331,28 +294,32 @@ class FalconPipelineForwards: block.__call__, hidden_states, alibi, - attention_mask, + causal_mask, position_ids, head_mask[i], - layer_past, + past_key_values, use_cache, output_attentions, + cache_position, + position_embeddings, ) else: outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, + layer_past=past_key_values, + attention_mask=causal_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, + position_embeddings=position_embeddings, ) hidden_states = outputs[0] if 
use_cache is True: - presents = presents + (outputs[1],) + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -365,6 +332,7 @@ class FalconPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): + if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 68a548aee..362f33176 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -246,6 +246,7 @@ class FalconPolicy(Policy): module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] + held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.h)) From 08787f0b6ec5571391432d3688a6162a739ba38e Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Mon, 5 May 2025 09:50:07 +0800 Subject: [PATCH 2/8] upgrade_bert --- colossalai/shardformer/modeling/bert.py | 97 +++++++++++++++++-------- 1 file changed, 67 insertions(+), 30 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index 580f3618c..bdd7e2f8a 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -3,6 +3,10 @@ from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -59,12 +63,11 @@ class BertPipelineForwards: stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, ): - # TODO(jianghai): add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: @@ -130,13 +133,43 @@ class BertPipelineForwards: # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if attention_mask is None: - attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) + use_sdpa_attention_masks = ( + self.attn_implementation == "sdpa" + and self.position_embedding_type == "absolute" + and head_mask is None + and not output_attentions + ) + + # Expand the attention mask + if use_sdpa_attention_masks and attention_mask.dim() == 2: + # Expand the attention mask for SDPA. + # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + if self.config.is_decoder: + extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + embedding_output, + past_key_values_length, + ) + else: + extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) - attention_mask = extended_attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: @@ -144,7 +177,14 @@ class BertPipelineForwards: encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: + # Expand the attention mask for SDPA. 
+ # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] + encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -164,7 +204,8 @@ class BertPipelineForwards: inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - + print("hidden_states:", hidden_states.shape) + print("bert_model_forward hidden_states:", hidden_states.shape) # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -211,30 +252,25 @@ class BertPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if self.encoder.gradient_checkpointing and self.encoder.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs, past_key_value, output_attentions) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(encoder_layer), + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, + attention_mask=extended_attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=extended_attention_mask, + head_mask=layer_head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if use_cache: @@ -1140,7 +1176,7 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - + print("embedding_output:", embedding_output.shape) # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] embedding_output = split_forward_gather_backward( @@ -1149,6 +1185,7 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication, ) + print("after split_forward_gather_backward embedding_output:", embedding_output.shape) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( encoder_hidden_states, From 5480b811c5ee4c90891e61d0f35c4e909dd8a17a Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 6 May 2025 15:58:53 +0800 Subject: [PATCH 3/8] upgrade_bloom --- colossalai/shardformer/modeling/bloom.py | 224 ++++++++++++----------- 1 file changed, 120 insertions(+), 104 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 7e8e50d9b..1e8b8b3e2 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -6,7 +6,7 @@ import torch.distributed as dist from torch.distributed import ProcessGroup 
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import functional as F -from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -108,7 +108,7 @@ class BloomPipelineForwards: def bloom_model_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -116,6 +116,7 @@ class BloomPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -151,6 +152,8 @@ class BloomPipelineForwards: if use_cache: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False + past_key_values = None + # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N @@ -161,46 +164,60 @@ class BloomPipelineForwards: # case: First stage of training if stage_manager.is_first_stage(): # check input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - # initialize in the first stage and then pass to the next stage - else: - input_shape = hidden_states.shape[:-1] - batch_size, seq_length = input_shape - - # extra recording tensor should be generated in the first stage - - presents = () if use_cache else None - all_self_attentions = () if output_attentions else None - all_hidden_states = () if output_hidden_states else None - - if self.gradient_checkpointing and self.training: - if use_cache: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + if self.gradient_checkpointing and self.training and use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
) use_cache = False - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) - # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] # source_len + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + hidden_states = self.word_embeddings_layernorm(inputs_embeds) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) + # initialize in the first stage and then pass to the next stage + else: + input_shape = hidden_states.shape[:-1] + batch_size, seq_length = input_shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=hidden_states.device) + + # extra recording tensor should be generated in the first stage + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + # Compute alibi tensor: check build_alibi_tensor documentation,build for every stage + past_length = 0 + seq_length_with_past = seq_length + past_length - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: @@ -209,13 +226,10 @@ class BloomPipelineForwards: alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) # causal_mask is constructed every stage and its input is passed through different stages - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, hidden_states, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] if shard_config and shard_config.enable_sequence_parallelism: @@ -228,9 +242,7 @@ class BloomPipelineForwards: ) start_idx, end_idx = stage_index[0], stage_index[1] - for i, (block, layer_past) in enumerate( - zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx - ): + for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -240,26 +252,28 @@ class BloomPipelineForwards: hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] + if use_cache: + next_decoder_cache = outputs[1] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -277,20 +291,23 @@ class BloomPipelineForwards: # Add last hidden state hidden_states = self.ln_f(hidden_states) - # TODO(jianghai): deal with all_hidden_states, all_self_attentions, presents if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if stage_manager.is_last_stage(): if not return_dict: return tuple( - v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None ) # attention_mask is not returned ; presents = past_key_values return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) @@ -845,7 +862,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): def forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + past_key_values: 
Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, @@ -853,6 +870,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **deprecated_arguments, ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop("position_ids", False) is not False: @@ -864,7 +882,6 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): ) if len(deprecated_arguments) > 0: raise ValueError(f"Got unexpected arguments: {deprecated_arguments}") - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -872,62 +889,60 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - if past_key_values is None: - past_key_values = tuple([None] * len(self.h)) + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + batch_size, seq_length, _ = inputs_embeds.shape + past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + seq_length_with_past = seq_length + past_length + if cache_position is None: + cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.n_layer) - - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(inputs_embeds) - presents = () if use_cache else None + next_decoder_cache = None all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Compute alibi tensor: check build_alibi_tensor documentation - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) else: attention_mask = attention_mask.to(hidden_states.device) alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) - - causal_mask = _prepare_4d_causal_attention_mask( - attention_mask, - input_shape=(batch_size, seq_length), - inputs_embeds=hidden_states, - past_key_values_length=past_key_values_length, + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ) - causal_mask = causal_mask.bool() - # split the input tensor along sequence dimension - # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] + hidden_states = split_forward_gather_backward( hidden_states, dim=1, @@ -935,7 +950,7 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): fp8_communication=shard_config.fp8_communication, ) - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -945,48 +960,49 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): hidden_states, alibi, causal_mask, - layer_past, + past_key_values, head_mask[i], use_cache, output_attentions, + cache_position, ) else: outputs = block( hidden_states, - layer_past=layer_past, + layer_past=past_key_values, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, + cache_position=cache_position, ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) + if use_cache: + next_decoder_cache = outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) - # When sequence parallelism done, gather the output tensor in forward and split it in backward - hidden_states = 
gather_forward_split_backward( - hidden_states, - dim=1, - process_group=shard_config.tensor_parallel_process_group, - fp8_communication=shard_config.fp8_communication, - ) # Add last hidden state hidden_states = self.ln_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: - return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attentions] if v is not None + ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) From 4eced5cf8a07bf70bb73dc6150351ed89d3e2af9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 8 May 2025 09:58:04 +0000 Subject: [PATCH 4/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- colossalai/shardformer/modeling/falcon.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index c2802063f..27461be04 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ b/colossalai/shardformer/modeling/falcon.py @@ -1,17 +1,9 @@ -import math -import warnings from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_attn_mask_utils import ( - AttentionMaskConverter, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) -from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -159,7 +151,7 @@ def get_tp_falcon_decoder_layer_forward(): and self.config.num_ln_in_parallel_attn == 1 ): mlp_layernorm_out = attention_layernorm_out - + outputs = attn_outputs[1:] # MLP. 
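The Bloom forwards above keep backward compatibility through `return_legacy_cache`: a legacy tuple-of-tuples cache is wrapped in a `DynamicCache` on entry and converted back with `to_legacy_cache()` on exit. A small round-trip example of that public transformers API (the tensor shapes are arbitrary illustrations):

# Legacy-cache round trip with the transformers Cache API.
import torch
from transformers.cache_utils import DynamicCache

# legacy format: tuple over layers of (key, value), each [batch, heads, seq, head_dim]
legacy = tuple((torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)   # wrap the tuples in a Cache object
print(cache.get_seq_length())                    # 6 cached positions

new_k = torch.randn(1, 4, 1, 8)
new_v = torch.randn(1, 4, 1, 8)
cache.update(new_k, new_v, layer_idx=0)          # append one decoding step to layer 0

roundtrip = cache.to_legacy_cache()              # back to tuple-of-tuples for BC
print(roundtrip[0][0].shape)                     # torch.Size([1, 4, 7, 8])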
@@ -215,7 +207,6 @@ class FalconPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - logger.warning_once("past_key_values is not supported for pipeline models at the moment.") past_key_values = None @@ -251,7 +242,7 @@ class FalconPipelineForwards: # Compute alibi tensor: check build_alibi_tensor documentation alibi = None past_key_values_length = 0 - + batch_size, seq_length, _ = hidden_states.shape if self.use_alibi: mask = ( @@ -262,7 +253,7 @@ class FalconPipelineForwards: else attention_mask ) alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) - + if cache_position is None: cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device @@ -280,7 +271,7 @@ class FalconPipelineForwards: # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - + # create position embeddings to be shared across the decoder layers position_embeddings = self.rotary_emb(hidden_states, position_ids) @@ -319,7 +310,7 @@ class FalconPipelineForwards: hidden_states = outputs[0] if use_cache is True: - next_decoder_cache = outputs[1] + outputs[1] if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) @@ -332,7 +323,7 @@ class FalconPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): - + if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None From fe94d73f6b0c0e431b3a8df57e9663911c4821ba Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 8 May 2025 18:03:53 +0800 Subject: [PATCH 5/8] fix --- colossalai/shardformer/modeling/bert.py | 95 ++++++++----------------- 1 file changed, 29 insertions(+), 66 deletions(-) diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py index bdd7e2f8a..580f3618c 100644 --- a/colossalai/shardformer/modeling/bert.py +++ b/colossalai/shardformer/modeling/bert.py @@ -3,10 +3,6 @@ from typing import List, Optional, Tuple, Union import torch from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.modeling_attn_mask_utils import ( - _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask_for_sdpa, -) from transformers.modeling_outputs import ( BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -63,11 +59,12 @@ class BertPipelineForwards: stage_index: Optional[List[int]] = None, shard_config: ShardConfig = None, ): + # TODO(jianghai): add explaination of the output here. r""" encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder. - encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*): + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder. 
Mask values selected in `[0, 1]`: @@ -133,43 +130,13 @@ class BertPipelineForwards: # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - token_type_ids=token_type_ids, - inputs_embeds=inputs_embeds, - past_key_values_length=past_key_values_length, - ) - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device) - use_sdpa_attention_masks = ( - self.attn_implementation == "sdpa" - and self.position_embedding_type == "absolute" - and head_mask is None - and not output_attentions - ) - - # Expand the attention mask - if use_sdpa_attention_masks and attention_mask.dim() == 2: - # Expand the attention mask for SDPA. - # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] - if self.config.is_decoder: - extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - embedding_output, - past_key_values_length, - ) - else: - extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( - attention_mask, embedding_output.dtype, tgt_len=seq_length - ) - else: - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + attention_mask = extended_attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] if self.config.is_decoder and encoder_hidden_states is not None: @@ -177,14 +144,7 @@ class BertPipelineForwards: encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) - if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2: - # Expand the attention mask for SDPA. 
- # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len] - encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length - ) - else: - encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_extended_attention_mask = None @@ -204,8 +164,7 @@ class BertPipelineForwards: inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - print("hidden_states:", hidden_states.shape) - print("bert_model_forward hidden_states:", hidden_states.shape) + # inherit from bert_layer,this should be changed when we add the feature to record hidden_states all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -252,25 +211,30 @@ class BertPipelineForwards: past_key_value = past_key_values[idx] if past_key_values is not None else None if self.encoder.gradient_checkpointing and self.encoder.training: - layer_outputs = self._gradient_checkpointing_func( - encoder_layer.__call__, + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(encoder_layer), hidden_states, - attention_mask=extended_attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, ) else: layer_outputs = encoder_layer( hidden_states, - attention_mask=extended_attention_mask, - head_mask=layer_head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, - output_attentions=output_attentions, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, ) hidden_states = layer_outputs[0] if use_cache: @@ -1176,7 +1140,7 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): inputs_embeds=inputs_embeds, past_key_values_length=past_key_values_length, ) - print("embedding_output:", embedding_output.shape) + # split the input tensor along sequence dimension # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size] embedding_output = split_forward_gather_backward( @@ -1185,7 +1149,6 @@ def bert_sequence_parallel_forward_fn(shard_config: ShardConfig): process_group=shard_config.tensor_parallel_process_group, fp8_communication=shard_config.fp8_communication, ) - print("after split_forward_gather_backward embedding_output:", embedding_output.shape) if encoder_hidden_states is not None: encoder_hidden_states = split_forward_gather_backward( encoder_hidden_states, From b124603c6898e9f3d2c26672b95a04dc56ac9f5e Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Thu, 8 May 2025 18:06:56 +0800 Subject: [PATCH 6/8] fix --- colossalai/shardformer/modeling/falcon.py | 133 ++++++++++++++-------- colossalai/shardformer/policies/falcon.py | 1 - 2 files changed, 87 insertions(+), 47 deletions(-) diff --git a/colossalai/shardformer/modeling/falcon.py b/colossalai/shardformer/modeling/falcon.py index 27461be04..8181a68a0 100644 --- a/colossalai/shardformer/modeling/falcon.py +++ 
b/colossalai/shardformer/modeling/falcon.py @@ -1,9 +1,16 @@ +import math +import warnings from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist from torch.distributed import ProcessGroup from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -103,18 +110,19 @@ def get_tp_falcon_decoder_layer_forward(): alibi: Optional[torch.Tensor], attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC **kwargs, ): - + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) residual = hidden_states - if self.config.new_decoder_architecture and self.config.num_ln_in_parallel_attn == 2: + if self.config.new_decoder_architecture: attention_layernorm_out = self.ln_attn(hidden_states) mlp_layernorm_out = self.ln_mlp(hidden_states) else: @@ -130,8 +138,7 @@ def get_tp_falcon_decoder_layer_forward(): head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - cache_position=cache_position, - position_embeddings=position_embeddings, + **kwargs, ) attention_output = attn_outputs[0] @@ -145,13 +152,6 @@ def get_tp_falcon_decoder_layer_forward(): ) mlp_layernorm_out = self.post_attention_layernorm(residual) - if ( - self.config.new_decoder_architecture - and self.config.parallel_attn - and self.config.num_ln_in_parallel_attn == 1 - ): - mlp_layernorm_out = attention_layernorm_out - outputs = attn_outputs[1:] # MLP. 
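For context on the sequence-parallel BERT/Bloom hunks earlier in this series: `split_forward_gather_backward` keeps only the local rank's slice of the sequence in the forward pass and reassembles the full gradient in the backward pass. The sketch below shows that general pattern with a plain `torch.autograd.Function`; it is a conceptual stand-in, not ColossalAI's implementation, and it assumes an initialized process group and a sequence length divisible by the world size.

# Conceptual split-in-forward / gather-in-backward along the sequence dimension.
import torch
import torch.distributed as dist

class _SplitForwardGatherBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim, ctx.group = dim, group
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        # keep only this rank's slice of the sequence
        return x.chunk(world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size(ctx.group)
        # gather the gradient slices from all ranks back into the full sequence
        pieces = [torch.empty_like(grad_output) for _ in range(world_size)]
        dist.all_gather(pieces, grad_output.contiguous(), group=ctx.group)
        return torch.cat(pieces, dim=ctx.dim), None, None

def split_seq(x, dim=1, group=None):
    return _SplitForwardGatherBackward.apply(x, dim, group)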
@@ -167,7 +167,7 @@ def get_tp_falcon_decoder_layer_forward(): else: outputs = (output,) + outputs[1:] - return outputs # hidden_states, past_kv, attentions + return outputs # hidden_states, present, attentions return forward @@ -190,7 +190,6 @@ class FalconPipelineForwards: output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, stage_manager: Optional[PipelineStageManager] = None, hidden_states: Optional[torch.FloatTensor] = None, stage_index: Optional[List[int]] = None, @@ -207,8 +206,9 @@ class FalconPipelineForwards: logger.warning_once("use_cache=True is not supported for pipeline models at the moment.") use_cache = False - logger.warning_once("past_key_values is not supported for pipeline models at the moment.") - past_key_values = None + if past_key_values is not None: + logger.warning_once("past_key_values is not supported for pipeline models at the moment.") + past_key_values = None return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -221,7 +221,7 @@ class FalconPipelineForwards: elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + raise ValueError("You have to specify either input_ids or inputs_embeds") if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) hidden_states = inputs_embeds @@ -229,9 +229,12 @@ class FalconPipelineForwards: input_shape = hidden_states.shape[:-1] batch_size, seq_length = input_shape + if past_key_values is None: + past_key_values = tuple([None] * len(self.h)) + if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( + logger.warning( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False @@ -240,10 +243,10 @@ class FalconPipelineForwards: all_hidden_states = () if output_hidden_states else None # Compute alibi tensor: check build_alibi_tensor documentation - alibi = None past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[-2] - batch_size, seq_length, _ = hidden_states.shape if self.use_alibi: mask = ( torch.ones( @@ -253,30 +256,73 @@ class FalconPipelineForwards: else attention_mask ) alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) + else: + alibi = None + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) - if cache_position is None: - cache_position = torch.arange( - past_key_values_length, past_key_values_length + seq_length, device=hidden_states.device + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self._use_sdpa and not output_attentions: + # output_attentions=True can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. 
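The revert above goes back to `_prepare_4d_causal_attention_mask`, i.e. a `[batch, 1, q_len, kv_len]` additive mask that is 0 where attention is allowed and dtype-min elsewhere. The sketch below builds the same kind of mask by hand to make the convention concrete; `make_4d_causal_mask` is a hypothetical helper, not the transformers function.

# Hand-rolled additive causal mask: 0 where attention is allowed, dtype-min where it is not.
import torch

def make_4d_causal_mask(padding_mask_2d: torch.Tensor, q_len: int, past_len: int, dtype=torch.float32):
    bsz, kv_len = padding_mask_2d.shape          # kv_len == past_len + q_len
    min_val = torch.finfo(dtype).min
    # causal part: query i may attend to keys 0 .. past_len + i
    causal = torch.full((q_len, kv_len), min_val, dtype=dtype)
    causal = causal.triu(diagonal=past_len + 1)  # future keys keep min_val, the rest become 0
    mask = causal[None, None, :, :].expand(bsz, 1, q_len, kv_len).clone()
    # padding part: keys marked 0 in the 2D mask are masked out for every query
    mask = mask.masked_fill(padding_mask_2d[:, None, None, :] == 0, min_val)
    return mask                                  # [batch, 1, q_len, kv_len]

pad = torch.tensor([[1, 1, 1, 0]])               # last key position is padding
print(make_4d_causal_mask(pad, q_len=4, past_len=0)[0, 0])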
+ if alibi is None: + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + elif head_mask is None: + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + + attention_mask_2d = attention_mask + # We don't call _prepare_4d_causal_attention_mask_for_sdpa as we need to mask alibi using the 4D attention_mask untouched. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + # We take care to integrate alibi bias in the attention_mask here. + if attention_mask_2d is None: + attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads) + else: + min_dtype = torch.finfo(alibi.dtype).min + attention_mask = torch.masked_fill( + alibi / math.sqrt(self.config.hidden_size // self.num_heads), + attention_mask < -1, + min_dtype, + ) + + # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend + # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213 + if seq_length > 1 and attention_mask.device.type == "cuda": + attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype) + else: + # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case. + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, hidden_states, cache_position, past_key_values, output_attentions, head_mask, alibi - ) - # Prepare head mask if needed # 1.0 in head_mask indicate we keep the head # attention_probs has shape batch_size x num_heads x N x N # head_mask has shape n_layer x batch x num_heads x N x N head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - start_idx, end_idx = stage_index[0], stage_index[1] - for i, block in enumerate(self.h[start_idx:end_idx], start=start_idx): + for i, (block, layer_past) in enumerate( + zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]), start=start_idx + ): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -285,32 +331,28 @@ class FalconPipelineForwards: block.__call__, hidden_states, alibi, - causal_mask, + attention_mask, position_ids, head_mask[i], - past_key_values, + layer_past, use_cache, output_attentions, - cache_position, - position_embeddings, ) else: outputs = block( hidden_states, - layer_past=past_key_values, - attention_mask=causal_mask, + layer_past=layer_past, + attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, - cache_position=cache_position, - position_embeddings=position_embeddings, ) hidden_states = outputs[0] if use_cache is True: - outputs[1] + presents = presents + (outputs[1],) if output_attentions: all_self_attentions = all_self_attentions + (outputs[2 if use_cache 
else 1],) @@ -323,7 +365,6 @@ class FalconPipelineForwards: all_hidden_states = all_hidden_states + (hidden_states,) if stage_manager.is_last_stage(): - if not return_dict: return tuple( v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py index 362f33176..68a548aee 100644 --- a/colossalai/shardformer/policies/falcon.py +++ b/colossalai/shardformer/policies/falcon.py @@ -246,7 +246,6 @@ class FalconPolicy(Policy): module = self.model.transformer stage_manager = self.pipeline_stage_manager held_layers = [] - held_layers.append(module.rotary_emb) if stage_manager.is_interleave: assert stage_manager.num_model_chunks is not None layers_per_stage = stage_manager.distribute_layers(len(module.h)) From d6f3508910e5bd10dbbfe6122730c8e193b38ff4 Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 13 May 2025 10:15:48 +0800 Subject: [PATCH 7/8] fix --- colossalai/shardformer/modeling/bloom.py | 87 ++++++++++++++++++++++++ colossalai/shardformer/policies/bloom.py | 10 +++ 2 files changed, 97 insertions(+) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1e8b8b3e2..f737e3b5e 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import ( BloomForSequenceClassification, BloomForTokenClassification, BloomModel, + dropout_add, ) from transformers.utils import logging @@ -856,6 +857,92 @@ def get_jit_fused_bloom_gelu_forward(): return forward +# Fixed the q_length args when doing the sequence parallelism in bloom model. +def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig): + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + batch_size, q_length, _ = hidden_states.shape + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + + # [batch_size * num_heads, q_length, kv_length] + attention_scores = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + if shard_config.enable_sequence_parallelism: + _, q_length, _ = query_layer.shape + # change view to [batch_size, num_heads, q_length, kv_length] + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + 
attn_weights = attn_weights + causal_mask + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) + + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, layer_past) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): from transformers import BloomModel diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index c7691698b..af49a4d19 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, + get_bloom_sequence_parallel_attention_forward, get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, @@ -61,6 +62,15 @@ class BloomPolicy(Policy): use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_sequence_parallel_attention_forward(self.shard_config), + }, + policy=policy, + target_key=BloomAttention, + ) + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.n_head % self.shard_config.tensor_parallel_size == 0 From 4fbbf4737a1007a63b8158b6114f2bb13922d74a Mon Sep 17 00:00:00 2001 From: wangbluo <2538539015@qq.com> Date: Tue, 13 May 2025 14:51:54 +0800 Subject: [PATCH 8/8] fix --- colossalai/shardformer/modeling/bloom.py | 63 +++++++++++------------- 1 file changed, 29 insertions(+), 34 deletions(-) diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index f737e3b5e..5ca8f9869 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -736,35 +736,24 @@ def get_jit_fused_bloom_attention_forward(): head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, ): + batch_size, q_length, _ = hidden_states.shape fused_qkv = 
self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - - batch_size, q_length, _, _ = query_layer.shape - - query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) if layer_past is not None: - past_key, past_value = layer_past - # concatenate along seq_length dimension: - # - key: [batch_size * self.num_heads, head_dim, kv_length] - # - value: [batch_size * self.num_heads, kv_length, head_dim] - key_layer = torch.cat((past_key, key_layer), dim=2) - value_layer = torch.cat((past_value, value_layer), dim=1) + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) - _, _, kv_length = key_layer.shape - - if use_cache is True: - present = (key_layer, value_layer) - else: - present = None + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) # [batch_size * num_heads, q_length, kv_length] - # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 - matmul_result = alibi.baddbmm( + attention_scores = alibi.baddbmm( batch1=query_layer, batch2=key_layer, beta=self.beta, @@ -772,15 +761,13 @@ def get_jit_fused_bloom_attention_forward(): ) # change view to [batch_size, num_heads, q_length, kv_length] - attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length) + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + attn_weights = attn_weights + causal_mask - # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length] - input_dtype = attention_scores.dtype - # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` - if input_dtype == torch.float16: - attention_scores = attention_scores.to(torch.float) - attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) - attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) @@ -789,12 +776,12 @@ def get_jit_fused_bloom_attention_forward(): attention_probs = attention_probs * head_mask # change view [batch_size x num_heads, q_length, kv_length] - attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length) + 
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(attention_probs_reshaped, value_layer)
 
-        # change view [batch_size, num_heads, q_length, head_dim]
+        # change view [batch_size, q_length, num_heads * head_dim]
         context_layer = self._merge_heads(context_layer)
 
         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
@@ -809,9 +796,9 @@ def get_jit_fused_bloom_attention_forward():
         else:
             output_tensor = self.dense(context_layer)
 
-        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
 
-        outputs = (output_tensor, present)
+        outputs = (output_tensor, layer_past)
 
         if output_attentions:
             outputs += (attention_probs,)
 
@@ -1072,6 +1059,14 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
                 if output_attentions:
                     all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
+        # When sequence parallelism is enabled, gather the output tensor along the sequence dimension in forward and split it in backward
+        hidden_states = gather_forward_split_backward(
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
+        )
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)
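
For reference, a minimal standalone sketch of the attention-score path these patches converge on for Bloom: the ALiBi bias fused into the QK^T product via `baddbmm`, the 4D additive mask sliced to the current key length, and an fp32 softmax cast back to the query dtype. All shapes and tensors below are toy values invented for illustration; only the call pattern mirrors the patched forward.

```python
import torch
import torch.nn.functional as F

batch_size, num_heads, q_len, kv_len, head_dim = 2, 4, 3, 6, 8
beta, inv_norm_factor = 1.0, 1.0 / head_dim**0.5        # stand-ins for self.beta / self.inv_norm_factor

query = torch.randn(batch_size * num_heads, q_len, head_dim)
key_t = torch.randn(batch_size * num_heads, head_dim, kv_len)   # key already transposed, as in the patch
value = torch.randn(batch_size * num_heads, kv_len, head_dim)
alibi = torch.randn(batch_size * num_heads, 1, kv_len)          # per-head linear position bias

# hypothetical pre-built 4D additive mask that is longer than the current key length
attention_mask = torch.zeros(batch_size, 1, q_len, kv_len + 4)
attention_mask[..., kv_len - 1] = torch.finfo(torch.float32).min  # pretend the last real key position is padding

# scores = beta * alibi + inv_norm_factor * (Q @ K^T), fused in one baddbmm call
attention_scores = alibi.baddbmm(batch1=query, batch2=key_t, beta=beta, alpha=inv_norm_factor)
attn_weights = attention_scores.view(batch_size, num_heads, q_len, kv_len)

# "no matter the length, we just slice it": trim the mask to key_layer.shape[-1]
causal_mask = attention_mask[:, :, :, :kv_len]
attn_weights = attn_weights + causal_mask

# softmax in fp32 for numerical stability, then cast back to the query dtype
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
context = torch.bmm(attention_probs.view(batch_size * num_heads, q_len, kv_len), value)
assert context.shape == (batch_size * num_heads, q_len, head_dim)
```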
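
And a conceptual, single-process sketch of what the `gather_forward_split_backward` call added above does with the sequence dimension. The real ColossalAI helper all-gathers shards across a tensor-parallel process group; the toy autograd function and the two hard-coded "ranks" below are stand-ins invented purely to show the gather-in-forward / split-in-backward pattern, not the library's implementation.

```python
import torch


class GatherForwardSplitBackward(torch.autograd.Function):
    """Concatenate sequence shards in forward; hand each rank only its own shard's gradient in backward."""

    @staticmethod
    def forward(ctx, local_shard, all_shards, dim, rank):
        ctx.dim, ctx.rank = dim, rank
        ctx.split_size = local_shard.size(dim)
        # simulate the all-gather: concatenate every rank's shard along the sequence dim
        return torch.cat(all_shards, dim=dim)

    @staticmethod
    def backward(ctx, grad_output):
        # split the full-sequence gradient and return only the slice belonging to this rank
        grads = torch.split(grad_output, ctx.split_size, dim=ctx.dim)
        return grads[ctx.rank], None, None, None


# toy usage: two "ranks", each holding half of the sequence dimension (dim=1)
shard0 = torch.randn(2, 3, 8, requires_grad=True)
shard1 = torch.randn(2, 3, 8, requires_grad=True)
full = GatherForwardSplitBackward.apply(shard0, [shard0.detach(), shard1.detach()], 1, 0)
full.sum().backward()
print(shard0.grad.shape)  # torch.Size([2, 3, 8]): only the local shard receives a gradient
```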