[shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests
Author: Hongxin Liu
Date: 2024-03-27 11:19:32 +08:00 (committed by GitHub)
Parent: 9a3321e9f4
Commit: 19e1a5cf16
45 changed files with 2543 additions and 1170 deletions
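The headline change is that shardformer modeling code now builds an additive attention bias itself and hands it to a static `ColoAttention.attention` call together with an `AttnMaskType`, instead of constructing a per-layer `ColoAttention` module. Below is a minimal sketch of how a caller might drive the new interface, assuming only the keyword arguments and mask types visible in the chatglm diff further down (`attention_mask`, `attention_mask_type`, `dropout_p`, `AttnMaskType.CAUSAL` / `AttnMaskType.CUSTOM`); the tensor shapes and padding pattern are made up for illustration, and whether this exact snippet runs depends on the installed attention kernels and the shapes/dtypes the dispatcher accepts.

```python
import torch

from colossalai.shardformer.layer import AttnMaskType, ColoAttention

# Hypothetical inputs in (batch, num_heads, seq_len, head_dim) layout, the layout
# the chatglm forward below permutes to before calling ColoAttention.
q = torch.randn(2, 8, 16, 64)
k, v = torch.randn_like(q), torch.randn_like(q)

# Custom mask: a boolean tensor marking positions that must NOT be attended to,
# converted to an additive bias (0 = keep, dtype-min = drop), as the diff does.
masked_out = torch.zeros(2, 1, 16, 16, dtype=torch.bool)
masked_out[1, :, :, 12:] = True  # e.g. the last 4 keys of sample 1 are padding
attn_bias = torch.zeros_like(masked_out, dtype=q.dtype)
attn_bias.masked_fill_(masked_out, torch.finfo(q.dtype).min)

out = ColoAttention.attention(
    q,
    k,
    v,
    attention_mask=attn_bias,
    attention_mask_type=AttnMaskType.CUSTOM,
    dropout_p=0.0,
)
print(out.shape)  # expected: torch.Size([2, 8, 16, 64])
```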


@@ -1,4 +1,5 @@
""" PyTorch ChatGLM model. """
from typing import List, Optional, Tuple
import torch
@@ -9,63 +10,49 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
+from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 def get_flash_core_attention_forward():
-    from colossalai.nn.layer.colo_attention import AttnMaskType, ColoAttention
     from .chatglm2_6b.modeling_chatglm import CoreAttention
     def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_mask):
-        pytorch_major_version = int(torch.__version__.split(".")[0])
-        if pytorch_major_version >= 2:
-            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
-            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
-                context_layer = torch.nn.functional.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, is_causal=True
-                )
-            else:
-                if attention_mask is not None:
-                    attention_mask = ~attention_mask
-                context_layer = torch.nn.functional.scaled_dot_product_attention(
-                    query_layer, key_layer, value_layer, attention_mask
-                )
-            context_layer = context_layer.permute(2, 0, 1, 3)
-            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
-            context_layer = context_layer.reshape(*new_context_layer_shape)
+        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
+        if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
+            attention_mask_type = AttnMaskType.CAUSAL
+            attn_bias = torch.zeros(
+                query_layer.shape[0],
+                1,
+                query_layer.shape[2],
+                key_layer.shape[2],
+                dtype=query_layer.dtype,
+                device=query_layer.device,
+            )
+            temp_mask = (
+                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                .tril(diagonal=0)
+                .expand(query_layer.shape[0], 1, -1, -1)
+            )
+            attn_bias.masked_fill_(temp_mask.logical_not(), torch.finfo(query_layer.dtype).min)
         else:
-            # Raw attention scores
-            query_layer = query_layer.permute(1, 0, 2, 3).contiguous()
-            key_layer = key_layer.permute(1, 0, 2, 3).contiguous()
-            value_layer = value_layer.permute(1, 0, 2, 3).contiguous()
-            scale = 1.0 / self.norm_factor
-            if self.coeff is not None:
-                scale = scale * self.coeff
-            flash_attention_mask = None
-            attn_mask_type = None
-            if attention_mask is None:
-                attn_mask_type = AttnMaskType.causal
-            else:
-                flash_attention_mask = ~(attention_mask[:, :, -1].squeeze(1).to(torch.bool)).contiguous()
-                if not torch.all(flash_attention_mask):
-                    attn_mask_type = AttnMaskType.paddedcausal
-            attention = ColoAttention(
-                embed_dim=self.hidden_size_per_partition,
-                num_heads=self.num_attention_heads_per_partition,
-                dropout=self.attention_dropout.p,
-                scale=scale,
-            )
-            context_layer = attention(
-                query_layer, key_layer, value_layer, attn_mask=flash_attention_mask, attn_mask_type=attn_mask_type
-            )
-            context_layer = context_layer.permute(1, 0, -1).contiguous()
+            attention_mask_type = AttnMaskType.CUSTOM
+            if attention_mask is not None:
+                attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
+                attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
+        dropout_p = self.attention_dropout.p if self.training else 0.0
+        context_layer = ColoAttention.attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attention_mask=attn_bias,
+            attention_mask_type=attention_mask_type,
+            dropout_p=dropout_p,
+        )
+        context_layer = context_layer.permute(2, 0, 1, 3)
+        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
+        context_layer = context_layer.reshape(*new_context_layer_shape)
         return context_layer
     return forward
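For intuition, the additive bias the new forward builds follows the same convention that `torch.nn.functional.scaled_dot_product_attention` accepts for `attn_mask`, so the causal branch can be sanity-checked against plain SDPA with nothing but stock PyTorch. A small self-contained sketch with toy shapes and no ColossalAI imports:

```python
import torch
import torch.nn.functional as F

b, h, s, d = 2, 4, 8, 16
q = torch.randn(b, h, s, d)
k, v = torch.randn(b, h, s, d), torch.randn(b, h, s, d)

# Causal additive bias, built the same way as the chatglm forward above:
# a lower-triangular "allowed" mask, with disallowed positions set to the dtype minimum.
allowed = torch.ones(s, s, dtype=torch.bool).tril(diagonal=0).expand(b, 1, -1, -1)
attn_bias = torch.zeros(b, 1, s, s, dtype=q.dtype)
attn_bias.masked_fill_(allowed.logical_not(), torch.finfo(q.dtype).min)

out_bias = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_bias, out_causal, atol=1e-6))  # expected: True
```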
@@ -169,11 +156,17 @@ class ChatGLMPipelineForwards:
         if self.pre_seq_len is not None:
             if past_key_values is None:
                 past_key_values = self.get_prompt(
-                    batch_size=batch_size, device=input_ids.device, dtype=inputs_embeds.dtype
+                    batch_size=batch_size,
+                    device=input_ids.device,
+                    dtype=inputs_embeds.dtype,
                 )
             if attention_mask is not None:
                 attention_mask = torch.cat(
-                    [attention_mask.new_ones((batch_size, self.pre_seq_len)), attention_mask], dim=-1
+                    [
+                        attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                        attention_mask,
+                    ],
+                    dim=-1,
                 )
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
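The prefix branch above is only re-wrapped by this commit; what it does is prepend a block of ones of length `pre_seq_len` for the learned (P-Tuning-style) prefix tokens, which are always attendable, to the regular padding mask. A toy illustration with made-up sizes:

```python
import torch

batch_size, seq_len, pre_seq_len = 2, 5, 3
# 1 = real token, 0 = padding (hypothetical batch)
attention_mask = torch.tensor([[1, 1, 1, 1, 0],
                               [1, 1, 0, 0, 0]])

# The learned prefix tokens are always visible, so a block of ones is prepended.
prefix_mask = attention_mask.new_ones((batch_size, pre_seq_len))
attention_mask = torch.cat([prefix_mask, attention_mask], dim=-1)
print(attention_mask.shape)  # torch.Size([2, 8])
```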
@@ -200,7 +193,9 @@ class ChatGLMPipelineForwards:
         if shard_config.enable_sequence_parallelism:
             hidden_states = split_forward_gather_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -208,7 +203,12 @@ class ChatGLMPipelineForwards:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.encoder.gradient_checkpointing and self.encoder.training:
                 layer_ret = torch.utils.checkpoint.checkpoint(
-                    layer, hidden_states, attention_mask, rotary_pos_emb, past_key_values[idx], use_cache
+                    layer,
+                    hidden_states,
+                    attention_mask,
+                    rotary_pos_emb,
+                    past_key_values[idx],
+                    use_cache,
                 )
             else:
                 layer_ret = layer(
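The checkpointing hunk above is a pure argument reflow, but for readers unfamiliar with the pattern: `torch.utils.checkpoint.checkpoint(fn, *args)` runs `fn` without storing its intermediate activations and recomputes them when the backward pass reaches that point. A generic sketch with a toy layer (not the ChatGLM block); `use_reentrant=False` selects the non-reentrant variant and may default differently on older PyTorch versions:

```python
import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(4, 16, requires_grad=True)

# Run the layer under activation checkpointing: intermediates inside `layer`
# are not kept; they are recomputed when backward() reaches this call.
out = torch.utils.checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([4, 16])
```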
@@ -224,7 +224,9 @@ class ChatGLMPipelineForwards:
         if shard_config.enable_sequence_parallelism:
             hidden_states = gather_forward_split_backward(
-                hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                hidden_states,
+                dim=0,
+                process_group=shard_config.tensor_parallel_process_group,
             )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -234,7 +236,14 @@ class ChatGLMPipelineForwards:
             hidden_states = self.encoder.final_layernorm(hidden_states)
         if not return_dict:
             return tuple(
-                v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None
+                v
+                for v in [
+                    hidden_states,
+                    presents,
+                    all_hidden_states,
+                    all_self_attentions,
+                ]
+                if v is not None
             )
         return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
@@ -368,7 +377,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         # Run encoder.
         # [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
         inputs_embeds = split_forward_gather_backward(
-            inputs_embeds, dim=0, process_group=shard_config.tensor_parallel_process_group
+            inputs_embeds,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
         )
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds,
@@ -380,7 +391,9 @@ def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
         )
         hidden_states = gather_forward_split_backward(
-            hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+            hidden_states,
+            dim=0,
+            process_group=shard_config.tensor_parallel_process_group,
         )
         if not return_dict:
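All of the sequence-parallel hunks in this file follow one pattern: split the activation's sequence dimension (dim 0 in ChatGLM's [seq_len, batch, hidden] layout) across the tensor-parallel group before the encoder, run the layers on the local shard, and gather the full sequence back afterwards. A condensed sketch of that pattern using the same helpers the diff imports; the encoder call is simplified (the real forward also threads masks, rotary embeddings, and KV caches), and a properly initialized `shard_config` from a launched ColossalAI job is assumed:

```python
from colossalai.shardformer.layer._operation import (
    gather_forward_split_backward,
    split_forward_gather_backward,
)


def sequence_parallel_encode(inputs_embeds, encoder, shard_config):
    # [seq_len, batch, hidden] -> [seq_len / tp_size, batch, hidden] on each rank
    hidden_states = split_forward_gather_backward(
        inputs_embeds,
        dim=0,
        process_group=shard_config.tensor_parallel_process_group,
    )
    hidden_states = encoder(hidden_states)  # simplified: real call passes masks, caches, etc.
    # [seq_len / tp_size, batch, hidden] -> [seq_len, batch, hidden]
    return gather_forward_split_backward(
        hidden_states,
        dim=0,
        process_group=shard_config.tensor_parallel_process_group,
    )
```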