From 2237531137d0a1dd475cc693747e987ced76850c Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 13 May 2025 18:21:57 +0800
Subject: [PATCH 1/2] update_bloom

---
 colossalai/shardformer/modeling/bert.py | 83 ++++++++++++++++++++++++-
 colossalai/shardformer/policies/bert.py | 11 ++++
 2 files changed, 93 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 580f3618c..0a0db434e 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -58,7 +58,7 @@ class BertPipelineForwards:
         hidden_states: Optional[torch.FloatTensor] = None,  # this is from the previous stage
         stage_index: Optional[List[int]] = None,
         shard_config: ShardConfig = None,
-    ):
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         # TODO(jianghai): add explaination of the output here.
         r"""
         encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1037,6 +1037,87 @@ def get_jit_fused_bert_output_forward():
     return forward
 
 
+def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
+    from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
+
+    def forward(
+        self: BertSdpaSelfAttention,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor]:
+
+        bsz, tgt_len, _ = hidden_states.size()
+
+        query_layer = self.transpose_for_scores(self.query(hidden_states))
+
+        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
+        # mask needs to be such that the encoder's padding tokens are not attended to.
+        is_cross_attention = encoder_hidden_states is not None
+
+        current_states = encoder_hidden_states if is_cross_attention else hidden_states
+        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
+
+        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
+        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
+            key_layer, value_layer = past_key_value
+        else:
+            key_layer = self.transpose_for_scores(self.key(current_states))
+            value_layer = self.transpose_for_scores(self.value(current_states))
+            if past_key_value is not None and not is_cross_attention:
+                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+
+        if self.is_decoder:
+            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+            # Further calls to cross_attention layer can then reuse all cross-attention
+            # key/value_states (first "if" case)
+            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+            # all previous decoder key/value_states. Further calls to uni-directional self-attention
+            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+            # if encoder bi-directional self-attention `past_key_value` is always `None`
+            past_key_value = (key_layer, value_layer)
+
+        # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
+        # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0.
+        # Reference: https://github.com/pytorch/pytorch/issues/112577
+        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
+            query_layer = query_layer.contiguous()
+            key_layer = key_layer.contiguous()
+            value_layer = value_layer.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
+        # a causal mask in case tgt_len == 1.
+        is_causal = (
+            True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False
+        )
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attention_mask,
+            dropout_p=self.dropout_prob if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        attn_output = attn_output.transpose(1, 2)
+        _, _, tgt_len, _ = query_layer.shape
+        attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
+
+        outputs = (attn_output,)
+        if self.is_decoder:
+            outputs = outputs + (past_key_value,)
+        return outputs
+
+    return forward
+
+
 def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
     def forward(
         self,
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 63cd49280..9f30622ee 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn
 from ..modeling.bert import (
     BertPipelineForwards,
     bert_sequence_parallel_forward_fn,
+    get_bert_sequence_parallel_attention_forward,
     get_jit_fused_bert_intermediate_forward,
     get_jit_fused_bert_output_forward,
     get_jit_fused_bert_self_output_forward,
@@ -48,6 +49,7 @@ class BertPolicy(Policy):
             BertLayer,
             BertModel,
             BertOutput,
+            BertSdpaSelfAttention,
             BertSelfOutput,
         )
 
@@ -77,6 +79,15 @@ class BertPolicy(Policy):
 
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
+        if self.shard_config.enable_sequence_parallelism:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
+                },
+                policy=policy,
+                target_key=BertSdpaSelfAttention,
+            )
+
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0

From d665d6740abc95c82a39345a13aa2652e8e655c1 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 14 May 2025 10:15:25 +0800
Subject: [PATCH 2/2] add explanation

---
 colossalai/shardformer/modeling/bert.py | 3 +++
 colossalai/shardformer/policies/bert.py | 1 +
 2 files changed, 4 insertions(+)

diff --git a/colossalai/shardformer/modeling/bert.py b/colossalai/shardformer/modeling/bert.py
index 0a0db434e..dcb832639 100644
--- a/colossalai/shardformer/modeling/bert.py
+++ b/colossalai/shardformer/modeling/bert.py
@@ -1037,6 +1037,9 @@ def get_jit_fused_bert_output_forward():
     return forward
 
 
+# Fix the tgt_len size in sequence parallel attention:
+# this forward is the same as BertSdpaSelfAttention.forward in transformers v4.51.3, except that
+# tgt_len is re-derived from the projected query via `_, _, tgt_len, _ = query_layer.shape`.
 def get_bert_sequence_parallel_attention_forward(shard_config: ShardConfig):
     from transformers.models.bert.modeling_bert import BertSdpaSelfAttention
 
diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index 9f30622ee..fd4e020b0 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -80,6 +80,7 @@ class BertPolicy(Policy):
         use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
 
         if self.shard_config.enable_sequence_parallelism:
+            # Fix the tgt_len size in the BERT sequence parallel attention forward.
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_bert_sequence_parallel_attention_forward(self.shard_config),
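
A minimal sketch of why re-deriving tgt_len is needed, assuming ColossalAI's split-gather sequence parallelism for BERT: the hidden states reaching BertSdpaSelfAttention hold only the local sequence chunk, while the sequence-gathering q/k/v projections (and hence the SDPA output) cover the full sequence, so a tgt_len taken from hidden_states.size() no longer matches the attention output. The sizes below are made up for illustration (a 2-way split of a 16-token sequence); only the shape logic mirrors the patched forward.

import torch

# Hypothetical sizes: 2-way sequence-parallel split of a 16-token sequence.
bsz, full_seq_len, sp_size, num_heads, head_dim = 2, 16, 2, 4, 8
hidden_size = num_heads * head_dim

# The attention module only sees its local chunk of the hidden states...
local_hidden_states = torch.randn(bsz, full_seq_len // sp_size, hidden_size)

# ...but the gathered q/k/v projections, and therefore the SDPA output, span the full sequence.
query_layer = torch.randn(bsz, num_heads, full_seq_len, head_dim)
attn_output = torch.randn(bsz, num_heads, full_seq_len, head_dim).transpose(1, 2)

# Upstream BertSdpaSelfAttention takes tgt_len from hidden_states (8 here), so the final
# reshape would fail under sequence parallelism:
stale_tgt_len = local_hidden_states.size(1)
# attn_output.reshape(bsz, stale_tgt_len, hidden_size)  # RuntimeError: invalid shape for input of size 1024

# The patched forward re-derives tgt_len from query_layer right before the reshape, which succeeds:
_, _, tgt_len, _ = query_layer.shape
print(attn_output.reshape(bsz, tgt_len, hidden_size).shape)  # torch.Size([2, 16, 32])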