[example] titans for gpt (#2484)

This commit is contained in:
Jiarui Fang
2023-01-16 15:55:41 +08:00
committed by GitHub
parent 7c31706227
commit 3a21485ead
14 changed files with 1754 additions and 5 deletions

@@ -0,0 +1,3 @@
from .embed import vocab_parallel_cross_entropy
from .gpt1d import *
from .pipeline_gpt1d import *

@@ -0,0 +1,599 @@
import torch
import torch.nn.init as init
from torch import Tensor
from torch import distributed as dist
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
from colossalai.nn.layer.parallel_1d.layers import Linear1D_Row
from colossalai.nn.layer.utils import divide
from colossalai.registry import LAYERS, LOSSES, MODELS
from colossalai.utils import get_current_device
class VocabParallelEmbedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=0,
dtype=torch.float):
super(VocabParallelEmbedding, self).__init__()
self.hidden_size = hidden_size
self.num_tokentypes = num_tokentypes
# Word embeddings (parallel).
self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
# self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
# Initialize the token-type embeddings.
# self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings are already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
# self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids=None, tokentype_ids=None):
# Embeddings.
if input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
words_embeddings = self.word_embeddings(input_ids)
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=get_current_device())
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
# Dropout.
with seed(ParallelMode.TENSOR):
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it',
flush=True)
class VocabParallelEmbedding1D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim, dtype=None, init_method=None):
super(VocabParallelEmbedding1D, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
# Set the details for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = gpc.tensor_parallel_size
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
# Allocate weights and initialize.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
init.uniform_(self.weight, -1, 1)
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
return output
@LOSSES.register_module
class vocab_parallel_cross_entropy(nn.Module):
def __init__(self):
super().__init__()
def forward(self, vocab_parallel_logits, target):
"""Helper function for the cross entropy."""
vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
target = target[..., 1:].contiguous()
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
target.view(-1))
class _VocabParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indices
get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
world_size = gpc.tensor_parallel_size
vocab_start_index, vocab_end_index = get_vocab_range(partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
loss = loss.mean()
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
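# d(loss)/d(logits) = softmax(logits) - one_hot(target); the one-hot
# subtraction only touches the rank that owns the target token
# (target_mask marks out-of-partition targets).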
# Retrieve tensors from the forward pass.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
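# Illustrative split (example numbers only): global_vocab_size=50304 with
# world_size=4 gives per_partition_vocab_size=12576, so rank 1 owns indices
# [12576, 25152).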
@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size)
class VocabParallelGPTLMHead1D(ParallelLayer):
"""
Language model head that shares the same parameters with the embedding matrix.
"""
def __init__(self, embed=None, vocab_size=None, dtype=None, embed_dim=None):
super().__init__()
if embed is not None:
self.head = embed
else:
self.head = VocabParallelEmbedding1D(vocab_size, embed_dim, dtype=dtype)
def forward(self, x: Tensor) -> Tensor:
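# reduce_grad is an identity in the forward pass and all-reduces gradients in the
# backward pass; the linear projection against the vocab-parallel embedding weight
# yields logits for the local vocabulary shard only, which is what
# vocab_parallel_cross_entropy expects.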
x = reduce_grad(x, ParallelMode.PARALLEL_1D)
x = F.linear(x, self.head.weight)
return x
###################################
class HiddenParallelEmbedding(torch.nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
max_sequence_length: maximum size of sequence. This
is used for positional embedding
embedding_dropout_prob: dropout probability for embeddings
init_method: weight initialization method
num_tokentypes: size of the token-type embeddings. 0 value
will ignore this embedding
"""
def __init__(
self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
dtype=torch.float,
padding_idx: int = 0,
num_tokentypes=0,
):
super(HiddenParallelEmbedding, self).__init__()
self.hidden_size = hidden_size
self.num_tokentypes = num_tokentypes
# Word embeddings (parallel).
self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
self._word_embeddings_key = 'word_embeddings'
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
# Initialize the position embeddings.
# self.init_method(self.position_embeddings.weight)
# Token type embedding.
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
# self.init_method(self.tokentype_embeddings.weight)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
def zero_parameters(self):
"""Zero out all parameters in embedding."""
self.word_embeddings.weight.data.fill_(0)
self.word_embeddings.weight.shared = True
self.position_embeddings.weight.data.fill_(0)
self.position_embeddings.weight.shared = True
if self.num_tokentypes > 0:
self.tokentype_embeddings.weight.data.fill_(0)
self.tokentype_embeddings.weight.shared = True
def add_tokentype_embeddings(self, num_tokentypes):
"""Add token-type embedding. This function is provided so we can add
token-type embeddings in case the pretrained model does not have it.
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings are already initialized')
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
# self.init_method(self.tokentype_embeddings.weight)
def forward(self, input_ids, position_ids=None, tokentype_ids=None):
if input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
words_embeddings = self.word_embeddings(input_ids)
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if position_ids is None:
position_ids = torch.arange(0, input_shape[-1], dtype=torch.long, device=get_current_device())
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
# Dropout.
with seed(ParallelMode.TENSOR):
embeddings = self.embedding_dropout(embeddings)
return embeddings
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
"""For easy load."""
state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
return state_dict_
def load_state_dict(self, state_dict, strict=True):
"""Customized load."""
# Word embedding.
if self._word_embeddings_key in state_dict:
state_dict_ = state_dict[self._word_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
# Position embedding.
if self._position_embeddings_key in state_dict:
state_dict_ = state_dict[self._position_embeddings_key]
else:
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
# Tokentype embedding.
if self.num_tokentypes > 0:
state_dict_ = {}
if self._tokentype_embeddings_key in state_dict:
state_dict_ = state_dict[self._tokentype_embeddings_key]
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it',
flush=True)
class HiddenParallelEmbedding1D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(self, num_embeddings, embedding_dim, dtype=torch.float, padding_idx: int = None, init_method=None):
super(HiddenParallelEmbedding1D, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
embed_dim_per_partition = divide(embedding_dim, gpc.tensor_parallel_size)
# Set the details for compatibility.
self.padding_idx = padding_idx
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Allocate weights and initialize.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
init.uniform_(self.weight, -1, 1)
def forward(self, input_):
# Get the embeddings.
output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
# Gather the hidden-dimension shards from all the model parallel GPUs.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
return output
@LAYERS.register_module
class HiddenParallelGPTLMHead1D(ParallelLayer):
"""
Language model head that shares the same parameters with the embedding matrix.
"""
def __init__(
self,
embed=None,
embed_dim=None,
vocab_size=None,
dtype=None,
):
super().__init__()
if embed is not None:
self.head = embed
self.synced_embed = True
else:
# self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
# (hidden_size/q, vocab_size)
self.synced_embed = False
self.head = Linear1D_Row(in_features=embed_dim,
out_features=vocab_size,
bias=False,
dtype=dtype,
parallel_input=False)
def forward(self, x: Tensor) -> Tensor:
if self.synced_embed:
x = F.linear(x, self.head.weight)
else:
x = self.head(x)
return x

@@ -0,0 +1,349 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
from torch import Tensor
from torch import nn as nn
from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.core import global_context as gpc
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.nn.layer import Linear1D_Col, Linear1D_Row
from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.nn.layer.utils import ACT2FN, divide
from colossalai.utils.activation_checkpoint import checkpoint
__all__ = [
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
]
class GPTMLP1D(ParallelLayer):
def __init__(
self,
in_features: int,
mlp_ratio: float,
act_func: str = 'gelu',
dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
):
super().__init__()
self.in_features = in_features
self.mlp_ratio = mlp_ratio
self.checkpoint = checkpoint
self.skip_bias_add = skip_bias_add
self.act = ACT2FN[act_func]
skip_dense_1_add_bias = False
# Project to mlp_ratio * h.
self.dense_1 = Linear1D_Col(
self.in_features,
int(self.mlp_ratio * self.in_features),
dtype=dtype,
gather_output=False,
skip_bias_add=skip_dense_1_add_bias,
)
# Project back to h.
self.dense_2 = Linear1D_Row(
int(self.mlp_ratio * self.in_features),
self.in_features,
dtype=dtype,
parallel_input=True,
)
self.dropout = col_nn.Dropout(dropout_prob)
def _forward(self, hidden_states: Tensor) -> Tensor:
intermediate_output = self.dense_1(hidden_states)
intermediate_output = self.act(intermediate_output)
output = self.dense_2(intermediate_output)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor) -> Tensor:
return checkpoint(self._forward, False, hidden_states)
def forward(self, hidden_states: Tensor) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states)
else:
return self._forward(hidden_states)
class GenericGPTSelfAttention1D(ParallelLayer):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024,
):
super().__init__()
self.hidden_size = hidden_size
self.attention_head_size = divide(hidden_size, num_attention_heads)
self.num_attention_heads_per_partition = divide(num_attention_heads, gpc.tensor_parallel_size)
self.hidden_size_per_partition = divide(hidden_size, gpc.tensor_parallel_size)
self.checkpoint = checkpoint
self.query_key_value = Linear1D_Col(
hidden_size,
3 * hidden_size,
dtype=dtype,
)
self.attention_dropout = col_nn.Dropout(attention_dropout_prob)
self.dense = Linear1D_Row(
hidden_size,
hidden_size,
dtype=dtype,
parallel_input=True,
)
self.dropout = col_nn.Dropout(hidden_dropout_prob)
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
raise NotImplementedError
def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
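# Fused QKV projection output has shape (batch, seq, 3 * hidden_size / parallel_size).
# Reshape to (batch, seq, heads_per_partition, 3 * head_size), move the head axis
# forward, then split the last axis into query, key and value.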
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = self.softmax_forward(attention_scores, attention_mask, query_layer, key_layer)
attention_scores = attention_scores.type(value_layer.dtype)
attention_probs = self.attention_dropout(attention_scores)
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.transpose(1, 2)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
output = self.dropout(output)
return output
def _checkpoint_forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
return checkpoint(self._forward, False, hidden_states, attention_mask)
def forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
if self.checkpoint:
return self._checkpoint_forward(hidden_states, attention_mask)
else:
return self._forward(hidden_states, attention_mask)
class GPTSelfAttention1D(GenericGPTSelfAttention1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
self.softmax = nn.Softmax(dim=-1)
max_positions = max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# causal mask
query_length, key_length = query_layer.size(-2), key_layer.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
if attention_mask is not None:
# Apply the attention mask
attention_scores = attention_scores + attention_mask
attention_scores = self.softmax(attention_scores)
return attention_scores
class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size))
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
return self.softmax(attention_scores, attention_mask)
class GenericGPTTransformerLayer1D(ParallelLayer):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
attention=None,
layer_norm=None):
super().__init__()
self.checkpoint = checkpoint
self.dtype = dtype
self.norm1 = layer_norm(hidden_size, eps=layer_norm_epsilon)
self.apply_post_layer_norm = apply_post_layer_norm
self.attention = attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
max_position_embeddings=max_position_embeddings,
checkpoint=False,
)
self.norm2 = layer_norm(hidden_size, eps=layer_norm_epsilon)
self.mlp = GPTMLP1D(
in_features=hidden_size,
dropout_prob=hidden_dropout_prob,
act_func=act_func,
mlp_ratio=mlp_ratio,
dtype=dtype,
checkpoint=False,
)
def _forward(self, hidden_states, attention_mask) -> Tensor:
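# Pre-LN (apply_post_layer_norm=False): the residual branches off before each layer
# norm; Post-LN (apply_post_layer_norm=True): the residual branches off after it.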
if not self.apply_post_layer_norm:
residual = hidden_states
hidden_states = self.norm1(hidden_states)
if self.apply_post_layer_norm:
residual = hidden_states
attention_output = self.attention(hidden_states, attention_mask)
hidden_states = residual + attention_output
if not self.apply_post_layer_norm:
residual = hidden_states
hidden_states = self.norm2(hidden_states)
if self.apply_post_layer_norm:
residual = hidden_states
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = residual + feed_forward_hidden_states
output = (hidden_states, attention_mask)
return output
def forward(self, hidden_states, attention_mask):
if self.checkpoint:
return checkpoint(self._forward, False, hidden_states, attention_mask)
else:
return self._forward(hidden_states, attention_mask)
class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
attention = GPTSelfAttention1D
layer_norm = nn.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)
class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
attention = FusedGPTSelfAttention1D
layer_norm = kernel.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)

@@ -0,0 +1,322 @@
import inspect
# import model_zoo.gpt.gpt as col_gpt
import titans.model.gpt.gpt as col_gpt
import torch
import torch.nn as nn
from colossalai import kernel
from colossalai import nn as col_nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.pipeline.utils import partition_uniform
from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabParallelEmbedding, VocabParallelGPTLMHead1D
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
__all__ = [
'GPT2_small_pipeline_1D',
'GPT2_exlarge_pipeline_1D',
'GPT3_pipeline_1D',
'GPT2_exlarge_pipeline_hybrid',
'GPT2_small_pipeline_hybrid',
'GPT3_pipeline_hybrid',
]
class GenericPipelineGPT(nn.Module):
def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
super().__init__()
self.embedding = embedding
self.blocks = blocks
self.norm = norm
self.head = head
assert blocks is not None
if norm is not None or head is not None:
assert norm is not None and head is not None
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.embedding is not None:
hidden_states = self.embedding(input_ids=input_ids)
batch_size = hidden_states.shape[0]
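# Turn the (batch, seq_len) padding mask into an additive bias that broadcasts over
# the attention scores: 0.0 for positions to attend to, -10000.0 for masked positions.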
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
if self.norm is not None:
hidden_states = self.head(self.norm(hidden_states))
return hidden_states
class PipelineGPT1D(GenericPipelineGPT):
def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
embedding = None
norm = None
head = None
embed_cls = VocabParallelEmbedding
head_cls = VocabParallelGPTLMHead1D
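# embed_split_hidden selects how the embedding and LM head are sharded: along the
# hidden dimension (HiddenParallel*) instead of along the vocabulary
# (VocabParallel*, the default).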
if embed_split_hidden:
embed_cls = HiddenParallelEmbedding
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList([
GPTTransformerLayer1D(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
])
if last:
norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
class FusedPipelineGPT1D(GenericPipelineGPT):
def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
embedding = None
norm = None
head = None
embed_cls = VocabParallelEmbedding
head_cls = VocabParallelGPTLMHead1D
if embed_split_hidden:
embed_cls = HiddenParallelEmbedding
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList([
FusedGPTTransformerLayer1D(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
])
if last:
norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.embedding is not None:
hidden_states = self.embedding(input_ids=input_ids)
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
if self.norm is not None:
hidden_states = self.head(self.norm(hidden_states))
return hidden_states
class PipelineGPTHybrid(GenericPipelineGPT):
def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: int = 4,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
embedding = None
norm = None
head = None
if first:
embedding = col_gpt.GPTEmbedding(hidden_size,
vocab_size,
max_position_embeddings,
dropout=embed_drop_rate,
dtype=dtype)
blocks = nn.ModuleList([
col_gpt.GPTBlock(hidden_size,
num_attention_heads,
mlp_ratio=mlp_ratio,
attention_dropout=attn_drop_rate,
dropout=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
activation=nn.functional.gelu) for _ in range(num_layers)
])
if last:
norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
# head = col_gpt.GPTLMHead(vocab_size=vocab_size,
# hidden_size=hidden_size,
# dtype=dtype,
# bias=False)
head = col_nn.Classifier(hidden_size, vocab_size, dtype=dtype, bias=False)
super().__init__(embedding=embedding, blocks=blocks, norm=norm, head=head)
def _filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
logger = get_dist_logger()
if gpc.is_initialized(ParallelMode.PIPELINE):
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
else:
pipeline_size = 1
pipeline_rank = 0
rank = gpc.get_global_rank()
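# PipelineSharedModuleWrapper ties parameters registered on the first and last
# pipeline stages, so the word embedding and the LM head share one weight matrix.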
if pipeline_size > 1:
wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
else:
wrapper = None
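# partition_uniform returns, for this pipeline rank, num_chunks contiguous
# (start, end) layer ranges to build.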
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = end - start
kwargs['first'] = start == 0
kwargs['last'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
if wrapper is not None:
if start == 0:
wrapper.register_module(chunk.embedding.word_embeddings)
elif end == num_layers:
wrapper.register_module(chunk.head)
models.append(chunk)
if len(models) == 1:
model = models[0]
else:
model = nn.ModuleList(models)
numel = 0
for _, param in model.named_parameters(recurse=True):
numel += param.numel()
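# numel * 2 reports parameter memory assuming 2 bytes per parameter (fp16/bf16).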
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
return model
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
model = FusedPipelineGPT1D if fused else PipelineGPT1D
return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
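# Minimal usage sketch (illustrative only): this assumes a Colossal-AI distributed
# environment with pipeline parallelism has already been launched, e.g. via
# colossalai.launch_from_torch with a matching pipeline/tensor parallel config.
#
#   model = GPT2_small_pipeline_1D(num_chunks=1,
#                                  checkpoint=True,
#                                  dtype=torch.half,
#                                  embed_split_hidden=False,
#                                  fused=False)
#
# Each pipeline rank then holds only its own slice of the 12 transformer layers,
# plus the word embedding (first stage) or the final norm and LM head (last stage).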