mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[Shardformer] Merge flash attention branch to pipeline branch (#4362)
* [shardformer] supported flash attention test dependency (#4158) * [shardformer] fix flash attention utils test (#4180) * [shardformer] opt support flash attention (#4163) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] add performance benchmark of shardformer (#4175) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] benchmark fix * [shardformer] benchmark fix * [shardformer] llama support flash attention (#4185) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] llama support flash attention * [shardformer] llama support flash attention * [shardformer] Move the import statement for xformer outside the forward function. * [shardformer] gpt2 support flash attention. (#4191) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] gpt2 support flash attention * [shardformer] bloom support flash attention (#4188) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom suport flash attention * [shardformer] add assert to sequence length * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] bert support flash attention. (#4206) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bert support flash attention * [shardformer] t5 support flash attention. (#4216) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 support flash attention * [shardformer] t5 support flash attention * fix typo * fix typo * fix typo * fix typo * fix typo * fix typo * [shardformer] support 'paddedcausal' type of attention mask in Coloattention. (#4215) * added padded causal attn mask type for ColoAttention * [shardformer]t5 flash attention fix (#4239) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] t5 flash attention fix * [shardformer] update gpt2 to use coloattention. (#4234) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 to use coloattention * [shardformer] update gpt2 * [shardformer] update opt and llama to use coloattention. (#4226) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt to use coloattention * [shardformer]update opt * [shardformer] shardformer support jit fused operator. (#4236) * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] opt support flash attention * [shardformer] move to modeling * [shardformer] move to modeling * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] bloom support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] t5 support jit fused operator * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add roadmap of flash attention * [shardformer] add type hint to 'self' param of forward * [shardformer] merge feature/shardformer-models branch to feature/flash-attention-shardformer branch. (#4290) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] whisper support flash attention (#4301) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] whisper support flash attention * [shardformer] whisper support flash attention * [shardformer]whisper support jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] sam support flash attention (#4316) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] sam support flash attention --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> * [shardformer] merge blip2/chatglm (#4321) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] blip2 support flash attention and jit operator (#4325) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator * [shardformer] blip2 support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] chatglm support flash attention and jit operator (#4330) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator * [shardformer] chatglm support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [shardformer] vit support flash attention and jit operator (#4334) * Feature/vit support (#4182) * [shardformer] added tests * [shardformer] vit test finish and support * fix attention dropout * [shardformer] support SAM (#4231) * 1.support sam 2.add fused qkv for nn.Linear * update utils support set element in list * overtwrite SamVisionAttention foward to use DropoutForParallelInput * remove unused code * [shardformer] support whisper (#4212) * support whisper * fix bug in vocabembedding * support downstream model of whisper * update readme * Feature/chatglm (#4240) * [shardformer] added tests * [shardformer] vit test finish and support * [shardformer] chatglm ready * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] chatglm shard without mlp sharding * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] fix chatglm configuration with pre-commit * [shardformer] added tests * [shardformer] vit test finish and support * import chatglm * [shardformer] add test kit in model zoo for chatglm * [sharformer] add first version of policy of chatglm * [shardformer] polish chatglm code * [shardformer] polish code * [shardformer] support chatglm without layernorm * [shardformer] delete some file * [shardformer] ChatGLM support layernorm sharding * [shardformer] register without auto policy * [shardformer] pre-commit check files * [shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit * [shardformer] support Blip2 (#4243) * support base blip2 * add support for downstream blip2 model * update readme * add forward injection * skip not compatible models test * fix test for gemini and low_level_zero_pugin * [shardformer] vit support flash attention and jit operator * [shardformer] vit support flash attention and jit operator --------- Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com> * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] merge flash attention branch * [pipeline] fix conflict * [pipeline] fix conflict * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * Merge branch 'feature/pipeline' into feature/pipeline * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * activate checks * fix flash attention tests * gemini ignore whisper * fix vit * fix xformers import handle --------- Co-authored-by: Frank Lee <somerlee.9@gmail.com> Co-authored-by: Kun Lin <81014421+klhhhhh@users.noreply.github.com> Co-authored-by: FoolPlayer <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: klhhhhh <1412841649@qq.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
from torch.nn import functional as F
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
@@ -675,3 +676,223 @@ class BloomPipelineForwards:
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def get_bloom_flash_attention_forward(enabel_jit_fused=False):
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention as me_attention
|
||||
except:
|
||||
raise ImportError("Error: xformers module is not installed. Please install it to use flash attention.")
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
def forward(
|
||||
self: BloomAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
|
||||
fused_qkv = self.query_key_value(hidden_states)
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
batch_size, tgt_len, _ = hidden_states.size()
|
||||
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
|
||||
|
||||
_, kv_length, _, _ = key_layer.size()
|
||||
|
||||
proj_shape = (batch_size, tgt_len, self.num_heads, self.head_dim)
|
||||
query_layer = query_layer.contiguous().view(*proj_shape)
|
||||
key_layer = key_layer.contiguous().view(*proj_shape)
|
||||
value_layer = value_layer.contiguous().view(*proj_shape)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
tgt_len = key_layer.size()[1]
|
||||
|
||||
attention_numerical_mask = torch.zeros((batch_size, self.num_heads, tgt_len, kv_length),
|
||||
dtype=torch.float32,
|
||||
device=query_layer.device,
|
||||
requires_grad=True)
|
||||
attention_numerical_mask = attention_numerical_mask + alibi.view(batch_size, self.num_heads, 1,
|
||||
kv_length) * self.beta
|
||||
attention_numerical_mask = torch.masked_fill(attention_numerical_mask, attention_mask,
|
||||
torch.finfo(torch.float32).min)
|
||||
|
||||
context_layer = me_attention(query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
attn_bias=attention_numerical_mask,
|
||||
scale=self.inv_norm_factor,
|
||||
p=self.attention_dropout.p)
|
||||
context_layer = context_layer.reshape(-1, kv_length, self.hidden_size)
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
# TODO to replace with the bias_dropout_add function in jit
|
||||
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
outputs = (output_tensor, present, None)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bloom_attention_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomAttention
|
||||
|
||||
def forward(
|
||||
self: BloomAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
|
||||
|
||||
batch_size, q_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
key_layer = key_layer.permute(0, 2, 3, 1).reshape(batch_size * self.num_heads, self.head_dim, q_length)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, head_dim, kv_length]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=2)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
_, _, kv_length = key_layer.shape
|
||||
|
||||
if use_cache is True:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
matmul_result = alibi.baddbmm(
|
||||
batch1=query_layer,
|
||||
batch2=key_layer,
|
||||
beta=self.beta,
|
||||
alpha=self.inv_norm_factor,
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16:
|
||||
attention_scores = attention_scores.to(torch.float)
|
||||
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
|
||||
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
# change view [batch_size x num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
context_layer = torch.bmm(attention_probs_reshaped, value_layer)
|
||||
|
||||
# change view [batch_size, num_heads, q_length, head_dim]
|
||||
context_layer = self._merge_heads(context_layer)
|
||||
|
||||
# aggregate results across tp ranks. See here: https://github.com/pytorch/pytorch/issues/76232
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
slices = self.hidden_size / self.pretraining_tp
|
||||
output_tensor = torch.zeros_like(context_layer)
|
||||
for i in range(self.pretraining_tp):
|
||||
output_tensor = output_tensor + F.linear(
|
||||
context_layer[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
output_tensor = self.dense(context_layer)
|
||||
|
||||
output_tensor = self.dropout_add(output_tensor, residual, self.hidden_dropout, self.training)
|
||||
|
||||
outputs = (output_tensor, present)
|
||||
if output_attentions:
|
||||
outputs += (attention_probs,)
|
||||
|
||||
return outputs
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bloom_mlp_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomMLP
|
||||
|
||||
def forward(self: BloomMLP, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.gelu_impl(self.dense_h_to_4h(hidden_states))
|
||||
|
||||
if self.pretraining_tp > 1 and self.slow_but_exact:
|
||||
intermediate_output = torch.zeros_like(residual)
|
||||
slices = self.dense_4h_to_h.weight.shape[-1] / self.pretraining_tp
|
||||
for i in range(self.pretraining_tp):
|
||||
intermediate_output = intermediate_output + F.linear(
|
||||
hidden_states[:, :, int(i * slices):int((i + 1) * slices)],
|
||||
self.dense_4h_to_h.weight[:, int(i * slices):int((i + 1) * slices)],
|
||||
)
|
||||
else:
|
||||
intermediate_output = self.dense_4h_to_h(hidden_states)
|
||||
output = self.dropout_add(intermediate_output, residual, self.hidden_dropout, self.training)
|
||||
return output
|
||||
|
||||
return forward
|
||||
|
||||
|
||||
def get_jit_fused_bloom_gelu_forward():
|
||||
|
||||
from transformers.models.bloom.modeling_bloom import BloomGelu
|
||||
|
||||
from colossalai.kernel.jit.bias_gelu import GeLUFunction as JitGeLUFunction
|
||||
|
||||
def forward(self: BloomGelu, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = torch.zeros_like(x)
|
||||
if self.training:
|
||||
return JitGeLUFunction.apply(x, bias)
|
||||
else:
|
||||
return self.bloom_gelu_forward(x, bias)
|
||||
|
||||
return forward
|
||||
|
Reference in New Issue
Block a user