diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py index 1e8b8b3e2..f737e3b5e 100644 --- a/colossalai/shardformer/modeling/bloom.py +++ b/colossalai/shardformer/modeling/bloom.py @@ -21,6 +21,7 @@ from transformers.models.bloom.modeling_bloom import ( BloomForSequenceClassification, BloomForTokenClassification, BloomModel, + dropout_add, ) from transformers.utils import logging @@ -856,6 +857,92 @@ def get_jit_fused_bloom_gelu_forward(): return forward +# Fixed the q_length args when doing the sequence parallelism in bloom model. +def get_bloom_sequence_parallel_attention_forward(shard_config: ShardConfig): + from transformers.models.bloom.modeling_bloom import BloomAttention + + def forward( + self: BloomAttention, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + batch_size, q_length, _ = hidden_states.shape + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + # reshape qkv for further computations + query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2) + value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim) + + # [batch_size * num_heads, q_length, kv_length] + attention_scores = alibi.baddbmm( + batch1=query_layer, + batch2=key_layer, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + if shard_config.enable_sequence_parallelism: + _, q_length, _ = query_layer.shape + # change view to [batch_size, num_heads, q_length, kv_length] + attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1) + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]] + attn_weights = attn_weights + causal_mask + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype) + + # [batch_size, num_heads, q_length, kv_length] + attention_probs = self.attention_dropout(attention_probs) + + if head_mask is not None: + attention_probs = attention_probs * head_mask + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1) + + # matmul: [batch_size * num_heads, q_length, head_dim] + context_layer = torch.bmm(attention_probs_reshaped, value_layer) + + # change view [batch_size, q_length, num_heads * head_dim] + context_layer = self._merge_heads(context_layer) + + # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) + + outputs = (output_tensor, layer_past) + if output_attentions: + outputs += (attention_probs,) + + return outputs + + return forward + + def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig): from transformers import BloomModel diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py index c7691698b..af49a4d19 100644 --- a/colossalai/shardformer/policies/bloom.py +++ b/colossalai/shardformer/policies/bloom.py @@ -11,6 +11,7 @@ import colossalai.shardformer.layer as col_nn from ..modeling.bloom import ( BloomPipelineForwards, build_bloom_alibi_tensor_fn, + get_bloom_sequence_parallel_attention_forward, get_bloom_sequence_parallel_forward_fn, get_jit_fused_bloom_attention_forward, get_jit_fused_bloom_gelu_forward, @@ -61,6 +62,15 @@ class BloomPolicy(Policy): use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv + if self.shard_config.enable_sequence_parallelism: + self.append_or_create_method_replacement( + description={ + "forward": get_bloom_sequence_parallel_attention_forward(self.shard_config), + }, + policy=policy, + target_key=BloomAttention, + ) + if self.shard_config.enable_tensor_parallelism: assert ( self.model.config.n_head % self.shard_config.tensor_parallel_size == 0