adapted for sequence parallel (#163)

This commit is contained in:
Frank Lee
2022-01-20 13:44:51 +08:00
committed by GitHub
parent a2e649da39
commit e2089c5c15
17 changed files with 432 additions and 119 deletions

View File

@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
class RingQK(torch.autograd.Function):
@@ -17,6 +18,7 @@ class RingQK(torch.autograd.Function):
"""
@staticmethod
@custom_fwd
def forward(ctx,
sub_q,
sub_k,
@@ -54,6 +56,7 @@ class RingQK(torch.autograd.Function):
return attention_score
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
sub_q, sub_k, = ctx.saved_tensors
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
@@ -64,6 +67,7 @@ class RingQK(torch.autograd.Function):
grad_output.transpose(2, 1),
sub_q
)
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size
@@ -94,6 +98,7 @@ class RingAV(torch.autograd.Function):
"""
@staticmethod
@custom_fwd
def forward(ctx,
attention_score,
sub_v,
@@ -131,6 +136,7 @@ class RingAV(torch.autograd.Function):
return sub_attention_result
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)

View File

@@ -2,15 +2,20 @@
# -*- encoding: utf-8 -*-
import math
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.context import seed
@LAYERS.register_module
@@ -31,136 +36,144 @@ class TransformerSelfAttentionRing(nn.Module):
def __init__(self,
hidden_size,
kv_channels,
num_attention_heads,
attention_dropout,
attention_mask_func,
layer_number,
apply_query_key_layer_scaling: bool = False,
convert_fp16_to_fp32_in_softmax: bool = False,
attn_mask_type=AttnMaskType.padding,
masked_softmax_fusion=True,
fp16=False,
bf16=False
):
super().__init__()
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_mask_func = attention_mask_func
self.layer_number = layer_number
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attn_mask_type = attn_mask_type
assert self.layer_number > 0
self.attention_dropout = attention_dropout
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = projection_size // num_attention_heads
if self.apply_query_key_layer_scaling:
self.convert_fp16_to_fp32_in_softmax = True
assert self.hidden_size % self.num_attention_heads == 0, \
'hidden size is not divisible by the number of attention heads'
self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# Strided linear layer.
self.query_key_value = nn.Linear(
self.query_key_value = _Linear(
hidden_size,
3 * projection_size,
3 * self.hidden_size,
)
# coeff = None
self.coeff = None
self.norm_factor = math.sqrt(self.hidden_size)
# TODO: add apply_query_key_layer_scaling when we have the kernel module
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
if self.apply_query_key_layer_scaling:
self.coeff = layer_number
self.norm_factor *= self.coeff
# TODO: add fused scale mask softmax kernel when we have the kernel module
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
# self.fp16, self.bf16,
# self.attn_mask_type,
# masked_softmax_fusion,
# attention_mask_func,
# self.attention_softmax_in_fp32,
# coeff)
self.scale_mask_softmax = FusedScaleMaskSoftmax(
fp16, bf16,
self.attn_mask_type,
masked_softmax_fusion,
self.attention_mask_func,
self.convert_fp16_to_fp32_in_softmax,
self.coeff)
self.attention_dropout = nn.Dropout(attention_dropout)
# Output.
self.dense = nn.Linear(
projection_size,
hidden_size,
bias=True)
self.dense = _Linear(hidden_size,
hidden_size,
bias=True,
skip_bias_add=True)
def forward(self, hidden_states, attention_mask):
# hidden_states: [sq, b, h]
# hidden_states: [sub_seq_len, batch_size, hidden_size]
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
sub_seq_length, batch_size, hidden_size = hidden_states.size()
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
# Attention heads shape change:
# [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
mixed_x_layer = self.query_key_value(hidden_states)
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
# [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# split into query, key and value
last_dim = mixed_x_layer.dim() - 1
last_dim_value = mixed_x_layer.size()[-1]
last_dim_value = mixed_x_layer.size(-1)
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
'cannot be divided into query, key and value'
partition_size = last_dim_value // 3
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer, partition_size, dim=last_dim)
# ===================================
# Raw attention scores. [b, num_heads, s, s]
# ===================================
# [b, num_heads, sq, sk]
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0) * self.world_size)
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
key_layer = key_layer.view(key_layer.size(0),
output_size[0] * output_size[1], -1)
# [b, sq, sk]
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
attention_scores = RingQK.apply(
# [b * num_heads, sq, hn]
query_layer.transpose(0, 1).contiguous(),
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
batch_size,
self.num_attention_heads,
sub_seq_length
)
attention_scores /= self.norm_factor
# change view to [b, num_heads, sq, sk]
# change view to [batch_size, num_heads, sub_seq_len, seq_len]
attention_scores = attention_scores.view(*output_size)
attention_scores = attention_scores.unsqueeze(1)
attention_scores = attention_scores + attention_mask
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.squeeze(1)
# change shape to [batch_size, num_heads, sub_seq_len, seq_len]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
# with mpu.get_cuda_rng_tracker().fork():
# TODO: check if a rng tracker is needed
attention_probs = self.attention_dropout(attention_probs)
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
# context layer shape: [b, num_heads, sq, hn]
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
#
# # change view [sk, b * num_heads, hn]
# change view [sub_seq_len, batch_size * num_heads, head_size]
value_layer = value_layer.contiguous().view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# # change view [b * num_heads, sq, sk]
# # change view [b * num_heads, sub_seq_len, seq_len]
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
attention_probs.size(2),
attention_probs.size(3))
# matmul: [b*num_heads, sq, hn]
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
context_layer = RingAV.apply(
attention_probs,
value_layer.transpose(0, 1).contiguous(),
@@ -170,19 +183,83 @@ class TransformerSelfAttentionRing(nn.Module):
sub_seq_length
)
# # change view [b, num_heads, sq, hn]
# change view [batch_size, num_heads, sub_seq_len, head_size]
context_layer = context_layer.view(*output_size)
# # [b, np, sq, hn] --> [sq, b, np, hn]
# [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# # [sq, b, np, hn] --> [sq, b, hp]
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_attention_head * self.num_attention_heads,)
context_layer = context_layer.view(*new_context_layer_shape)
# context_layer = context_layer.transpose(1, 0).contiguous()
output = self.dense(context_layer)
bias = self.dense.bias
output, bias = self.dense(context_layer)
return output, bias
def __repr__(self):
return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'
class _Linear(nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
super(_Linear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
nn.init.xavier_normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output = F.linear(input_, self.weight, bias)
if self.skip_bias_add:
return output, self.bias
else:
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'