diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index f737e3b5e..5ca8f9869 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -736,35 +736,24 @@ def get_jit_fused_bloom_attention_forward():
         head_mask: Optional[torch.Tensor] = None,
         use_cache: bool = False,
         output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
     ):
+        batch_size, q_length, _ = hidden_states.shape
         fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+        # 3 x [batch_size, num_heads, seq_length, head_dim]
+        query_layer, key_layer, value_layer = self._reshape(fused_qkv)
 
-        # 3 x [batch_size, seq_length, num_heads, head_dim]
-        (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-
-        batch_size, q_length, _, _ = query_layer.shape
-
-        query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
-        key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
-        value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
         if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension:
-            #  - key: [batch_size * self.num_heads, head_dim, kv_length]
-            #  - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=2)
-            value_layer = torch.cat((past_value, value_layer), dim=1)
+            cache_kwargs = {"cache_position": cache_position}
+            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
-        _, _, kv_length = key_layer.shape
-
-        if use_cache is True:
-            present = (key_layer, value_layer)
-        else:
-            present = None
+        # reshape qkv for further computations
+        query_layer = query_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
+        key_layer = key_layer.reshape(batch_size * self.num_heads, -1, self.head_dim).transpose(-1, -2)
+        value_layer = value_layer.reshape(batch_size * self.num_heads, -1, self.head_dim)
 
         # [batch_size * num_heads, q_length, kv_length]
-        # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
-        matmul_result = alibi.baddbmm(
+        attention_scores = alibi.baddbmm(
             batch1=query_layer,
             batch2=key_layer,
             beta=self.beta,
@@ -772,15 +761,13 @@ def get_jit_fused_bloom_attention_forward():
         )
 
         # change view to [batch_size, num_heads, q_length, kv_length]
-        attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
+        attn_weights = attention_scores.view(batch_size, self.num_heads, q_length, -1)
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_layer.shape[-1]]
+            attn_weights = attn_weights + causal_mask
 
-        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
-        input_dtype = attention_scores.dtype
-        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
-        if input_dtype == torch.float16:
-            attention_scores = attention_scores.to(torch.float)
-        attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
-        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
+        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
+        attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_layer.dtype)
 
         # [batch_size, num_heads, q_length, kv_length]
         attention_probs = self.attention_dropout(attention_probs)
@@ -789,12 +776,12 @@ def get_jit_fused_bloom_attention_forward():
             attention_probs = attention_probs * head_mask
 
         # change view [batch_size x num_heads, q_length, kv_length]
-        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
+        attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, -1)
 
         # matmul: [batch_size * num_heads, q_length, head_dim]
         context_layer = torch.bmm(attention_probs_reshaped, value_layer)
 
-        # change view [batch_size, num_heads, q_length, head_dim]
+        # change view [batch_size, q_length, num_heads * head_dim]
         context_layer = self._merge_heads(context_layer)
 
         # aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
@@ -809,9 +796,9 @@ def get_jit_fused_bloom_attention_forward():
         else:
             output_tensor = self.dense(context_layer)
 
-        output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
+        output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
 
-        outputs = (output_tensor, present)
+        outputs = (output_tensor, layer_past)
         if output_attentions:
             outputs += (attention_probs,)
 
@@ -1072,6 +1059,14 @@ def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
             if output_attentions:
                 all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
 
+        # When sequence parallelism done, gather the output tensor in forward and split it in backward
+        hidden_states = gather_forward_split_backward(
+            hidden_states,
+            dim=1,
+            process_group=shard_config.tensor_parallel_process_group,
+            fp8_communication=shard_config.fp8_communication,
+        )
+
         # Add last hidden state
         hidden_states = self.ln_f(hidden_states)