mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
Migrated project
This commit is contained in:
4
colossalai/nn/layer/parallel_sequence/__init__.py
Normal file
4
colossalai/nn/layer/parallel_sequence/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from ._operation import RingQK, RingAV
|
||||
from .layers import TransformerSelfAttentionRing
|
||||
|
||||
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']
|
169
colossalai/nn/layer/parallel_sequence/_operation.py
Normal file
169
colossalai/nn/layer/parallel_sequence/_operation.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
|
||||
from colossalai.communication import ring_forward
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class RingQK(torch.autograd.Function):
|
||||
"""
|
||||
Calculate QK in a ring-exchange style
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
sub_q,
|
||||
sub_k,
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
sub_seq_length):
|
||||
# save tensor for backward
|
||||
ctx.save_for_backward(sub_q, sub_k)
|
||||
ctx.sub_seq_length = sub_seq_length
|
||||
|
||||
# create local segment of attention score
|
||||
attention_score = torch.empty(
|
||||
batch_size * num_attention_heads,
|
||||
sub_seq_length,
|
||||
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
|
||||
dtype=sub_q.dtype,
|
||||
device=get_current_device()
|
||||
)
|
||||
|
||||
# compute local QK^T
|
||||
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
|
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
|
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
start_idx = local_rank * sub_seq_length
|
||||
end_idx = (local_rank + 1) * sub_seq_length
|
||||
attention_score[:, :, start_idx: end_idx] = part_a
|
||||
|
||||
# compute QK^T in ring-all-reduce style
|
||||
for i in range(local_world_size - 1):
|
||||
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
|
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
|
||||
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
|
||||
attention_score[:, :, start_idx:end_idx] = part_a
|
||||
|
||||
return attention_score
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
sub_q, sub_k, = ctx.saved_tensors
|
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
|
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
|
||||
# calculate gradient of sub_k
|
||||
grad_k = torch.matmul(
|
||||
grad_output.transpose(2, 1),
|
||||
sub_q
|
||||
)
|
||||
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
|
||||
grad_k /= local_world_size
|
||||
|
||||
# calculate gradient for sub_q
|
||||
grad_q = torch.zeros_like(sub_q,
|
||||
dtype=sub_q.dtype,
|
||||
device=get_current_device(), )
|
||||
|
||||
# compute with local sub_k
|
||||
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
|
||||
grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
|
||||
|
||||
# compute QK^T in ring-all-reduce style
|
||||
for i in range(local_world_size - 1):
|
||||
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
|
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
|
||||
grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
|
||||
|
||||
grad_q /= local_world_size
|
||||
|
||||
return grad_q, grad_k, None, None, None
|
||||
|
||||
|
||||
class RingAV(torch.autograd.Function):
|
||||
"""
|
||||
Calculate AV in a ring-exchange style
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
attention_score,
|
||||
sub_v,
|
||||
batch_size,
|
||||
num_attention_heads,
|
||||
attention_head_size,
|
||||
sub_seq_length):
|
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
|
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
|
||||
|
||||
sub_attention_result = torch.zeros(
|
||||
batch_size * num_attention_heads,
|
||||
sub_seq_length,
|
||||
attention_head_size,
|
||||
device=get_current_device(),
|
||||
dtype=attention_score.dtype)
|
||||
|
||||
# save tensors for backward
|
||||
ctx.save_for_backward(attention_score, sub_v)
|
||||
ctx.sub_seq_length = sub_seq_length
|
||||
|
||||
# compute local AV
|
||||
part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)
|
||||
sub_attention_result += part_av
|
||||
|
||||
# compute AV in ring - all - reduce style
|
||||
for i in range(local_world_size - 1):
|
||||
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
|
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
|
||||
|
||||
# compute QK^T
|
||||
part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)
|
||||
sub_attention_result += part_av
|
||||
return sub_attention_result
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
|
||||
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
|
||||
attention_scores, sub_v = ctx.saved_tensors
|
||||
|
||||
# calculate gradient of v
|
||||
grad_v = torch.matmul(
|
||||
attention_scores.transpose(2, 1),
|
||||
grad_output
|
||||
)
|
||||
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
|
||||
grad_v = grad_v[:, local_start_idx:local_end_idx]
|
||||
grad_v /= local_world_size
|
||||
|
||||
# calculate gradient for attention score
|
||||
grad_attention_score = torch.zeros_like(attention_scores,
|
||||
dtype=grad_output.dtype,
|
||||
device=get_current_device())
|
||||
|
||||
# compute with local sub_k
|
||||
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
|
||||
grad_output,
|
||||
sub_v.transpose(2, 1))
|
||||
|
||||
# compute QK^T in ring-all-reduce style
|
||||
for i in range(local_world_size - 1):
|
||||
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
|
||||
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
|
||||
|
||||
# compute grad_q
|
||||
grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
|
||||
grad_output,
|
||||
sub_v.transpose(2, 1))
|
||||
|
||||
return grad_attention_score, grad_v, None, None, None, None
|
15
colossalai/nn/layer/parallel_sequence/_utils.py
Normal file
15
colossalai/nn/layer/parallel_sequence/_utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
|
||||
device_of_incoming_k = (rank - i - 1) % world_size
|
||||
start_idx = sub_seq_length * device_of_incoming_k
|
||||
end_idx = sub_seq_length * (device_of_incoming_k + 1)
|
||||
return start_idx, end_idx
|
||||
|
||||
|
||||
def _calc_current_device_range(rank, sub_seq_length):
|
||||
start_idx = sub_seq_length * rank
|
||||
end_idx = sub_seq_length * (rank + 1)
|
||||
return start_idx, end_idx
|
188
colossalai/nn/layer/parallel_sequence/layers.py
Normal file
188
colossalai/nn/layer/parallel_sequence/layers.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
|
||||
from colossalai.registry import LAYERS
|
||||
|
||||
|
||||
@LAYERS.register_module
|
||||
class TransformerSelfAttentionRing(nn.Module):
|
||||
"""Parallel self-attention layer abstract class.
|
||||
Self-attention layer takes input with size [b, s, h]
|
||||
and returns output of the same size.
|
||||
|
||||
:param hidden_size: hidden size
|
||||
:type hidden_size: int
|
||||
:param kv_channels: channels of key/value tensor
|
||||
:type kv_channels: int
|
||||
:param num_attention_heads: number of attention heads
|
||||
:type num_attention_heads: int
|
||||
:param attention_dropout: dropout probability for attention layer
|
||||
:type attention_dropout: float
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
kv_channels,
|
||||
num_attention_heads,
|
||||
attention_dropout,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
projection_size = kv_channels * num_attention_heads
|
||||
self.hidden_size_per_attention_head = projection_size // num_attention_heads
|
||||
|
||||
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
|
||||
|
||||
# Strided linear layer.
|
||||
self.query_key_value = nn.Linear(
|
||||
hidden_size,
|
||||
3 * projection_size,
|
||||
)
|
||||
|
||||
# coeff = None
|
||||
self.norm_factor = math.sqrt(self.hidden_size)
|
||||
|
||||
# TODO: add apply_query_key_layer_scaling when we have the kernel module
|
||||
# if self.apply_query_key_layer_scaling:
|
||||
# coeff = self.layer_number
|
||||
# self.norm_factor *= coeff
|
||||
|
||||
# TODO: add fused scale mask softmax kernel when we have the kernel module
|
||||
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
|
||||
# self.fp16, self.bf16,
|
||||
# self.attn_mask_type,
|
||||
# masked_softmax_fusion,
|
||||
# attention_mask_func,
|
||||
# self.attention_softmax_in_fp32,
|
||||
# coeff)
|
||||
|
||||
self.attention_dropout = nn.Dropout(attention_dropout)
|
||||
|
||||
# Output.
|
||||
self.dense = nn.Linear(
|
||||
projection_size,
|
||||
hidden_size,
|
||||
bias=True)
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
# hidden_states: [sq, b, h]
|
||||
|
||||
sub_seq_length, batch_size, hidden_size = hidden_states.size()
|
||||
|
||||
# =====================
|
||||
# Query, Key, and Value
|
||||
# =====================
|
||||
|
||||
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
|
||||
mixed_x_layer = self.query_key_value(hidden_states)
|
||||
|
||||
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
|
||||
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
|
||||
3 * self.hidden_size_per_attention_head)
|
||||
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
|
||||
|
||||
# split into query, key and value
|
||||
last_dim = mixed_x_layer.dim() - 1
|
||||
last_dim_value = mixed_x_layer.size()[-1]
|
||||
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
|
||||
'cannot be divided into query, key and value'
|
||||
partition_size = last_dim_value // 3
|
||||
(query_layer, key_layer, value_layer) = torch.split(
|
||||
mixed_x_layer, partition_size, dim=last_dim)
|
||||
|
||||
# ===================================
|
||||
# Raw attention scores. [b, num_heads, s, s]
|
||||
# ===================================
|
||||
|
||||
# [b, num_heads, sq, sk]
|
||||
output_size = (query_layer.size(1),
|
||||
query_layer.size(2),
|
||||
query_layer.size(0),
|
||||
key_layer.size(0) * self.world_size)
|
||||
|
||||
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
|
||||
query_layer = query_layer.view(output_size[2],
|
||||
output_size[0] * output_size[1], -1)
|
||||
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
|
||||
key_layer = key_layer.view(key_layer.size(0),
|
||||
output_size[0] * output_size[1], -1)
|
||||
|
||||
# [b, sq, sk]
|
||||
attention_scores = RingQK.apply(
|
||||
# [b * num_heads, sq, hn]
|
||||
query_layer.transpose(0, 1).contiguous(),
|
||||
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
|
||||
batch_size,
|
||||
self.num_attention_heads,
|
||||
sub_seq_length
|
||||
)
|
||||
attention_scores /= self.norm_factor
|
||||
|
||||
# change view to [b, num_heads, sq, sk]
|
||||
attention_scores = attention_scores.view(*output_size)
|
||||
attention_scores = attention_scores.unsqueeze(1)
|
||||
|
||||
attention_scores = attention_scores + attention_mask
|
||||
attention_probs = F.softmax(attention_scores, dim=-1)
|
||||
attention_probs = attention_probs.squeeze(1)
|
||||
|
||||
# This is actually dropping out entire tokens to attend to, which might
|
||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||
# with mpu.get_cuda_rng_tracker().fork():
|
||||
# TODO: check if a rng tracker is needed
|
||||
attention_probs = self.attention_dropout(attention_probs)
|
||||
|
||||
# context layer shape: [b, num_heads, sq, hn]
|
||||
output_size = (value_layer.size(1),
|
||||
value_layer.size(2),
|
||||
query_layer.size(0),
|
||||
value_layer.size(3))
|
||||
#
|
||||
# # change view [sk, b * num_heads, hn]
|
||||
value_layer = value_layer.contiguous().view(value_layer.size(0),
|
||||
output_size[0] * output_size[1], -1)
|
||||
|
||||
# # change view [b * num_heads, sq, sk]
|
||||
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
|
||||
attention_probs.size(2),
|
||||
attention_probs.size(3))
|
||||
|
||||
# matmul: [b*num_heads, sq, hn]
|
||||
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
|
||||
context_layer = RingAV.apply(
|
||||
attention_probs,
|
||||
value_layer.transpose(0, 1).contiguous(),
|
||||
batch_size,
|
||||
self.num_attention_heads,
|
||||
self.hidden_size_per_attention_head,
|
||||
sub_seq_length
|
||||
)
|
||||
|
||||
# # change view [b, num_heads, sq, hn]
|
||||
context_layer = context_layer.view(*output_size)
|
||||
|
||||
# # [b, np, sq, hn] --> [sq, b, np, hn]
|
||||
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
|
||||
|
||||
# # [sq, b, np, hn] --> [sq, b, hp]
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (
|
||||
self.hidden_size_per_attention_head * self.num_attention_heads,)
|
||||
context_layer = context_layer.view(*new_context_layer_shape)
|
||||
|
||||
# context_layer = context_layer.transpose(1, 0).contiguous()
|
||||
output = self.dense(context_layer)
|
||||
bias = self.dense.bias
|
||||
|
||||
return output, bias
|
Reference in New Issue
Block a user