[tutorial] edited hands-on practices (#1899)

* Add hands-on examples to ColossalAI.

* Change names of hands-on examples and edit the sequence parallel example.

* Fix wrong folder name.

* Resolve conflict.

* Delete readme.
Author: BoxiangW
Date: 2022-11-11 04:08:17 -05:00
Committed by: GitHub
Parent: d9bf83e084
Commit: ca6e75bc28
121 changed files with 20464 additions and 0 deletions

@@ -0,0 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .head import BertDualHead
from .preprocess import PreProcessor

@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing
from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from colossalai.kernel.cuda_native import LayerNorm
from .mlp import TransformerMLP
from .dropout import get_bias_dropout_add
def attention_mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores
class BertLayer(nn.Module):
"""A single transformer layer.
Transformer layer takes input with size [b, s, h] and returns an
output of the same size.
"""
def __init__(self,
layer_number,
hidden_size,
num_attention_heads,
attention_dropout,
mlp_ratio,
hidden_dropout,
is_naive_fp16,
apply_residual_connection_post_layernorm=False,
fp32_residual_connection=False,
bias_dropout_fusion: bool = True,
convert_fp16_to_fp32_in_softmax: bool = False):
super().__init__()
self.layer_number = layer_number
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.fp32_residual_connection = fp32_residual_connection
# Layernorm on the input data.
self.input_layernorm = LayerNorm(hidden_size)
# Self attention.
self.self_attention = TransformerSelfAttentionRing(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
attention_dropout=attention_dropout,
attention_mask_func=attention_mask_func,
layer_number=layer_number,
apply_query_key_layer_scaling=True,
convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
fp16=is_naive_fp16
)
self.hidden_dropout = hidden_dropout
self.bias_dropout_fusion = bias_dropout_fusion
# Layernorm on the attention output
self.post_attention_layernorm = LayerNorm(hidden_size)
self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)
def forward(self, hidden_states, attention_mask):
# hidden_states: [batch_size, sub_seq_len, hidden_size]
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
# Layer norm at the beginning of the transformer layer.
layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask)
# Residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = hidden_states
        # JIT scripting for an nn.Module (with dropout) does not
        # trigger the fusion kernel. For now, we use two different
        # nn.functional routines to account for the different dropout
        # semantics during training and inference.
if self.bias_dropout_fusion:
if self.training:
bias_dropout_add_func = bias_dropout_add_fused_train
else:
bias_dropout_add_func = bias_dropout_add_fused_inference
else:
bias_dropout_add_func = get_bias_dropout_add(self.training)
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
layernorm_input = bias_dropout_add_func(
attention_output,
attention_bias.expand_as(residual),
residual,
self.hidden_dropout)
# Layer norm post the self attention.
layernorm_output = self.post_attention_layernorm(layernorm_input)
# MLP.
mlp_output, mlp_bias = self.mlp(layernorm_output)
# Second residual connection.
if self.apply_residual_connection_post_layernorm:
residual = layernorm_output
else:
residual = layernorm_input
# re-enable torch grad to enable fused optimization.
with torch.enable_grad():
output = bias_dropout_add_func(
mlp_output,
mlp_bias.expand_as(residual),
residual,
self.hidden_dropout)
return output
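
The residual wiring and the unfused bias-dropout-add fallback used above can be exercised with plain tensors, independent of the ColossalAI kernels. This is a minimal sketch with made-up shapes; it mirrors the pure-PyTorch helper shown in the next file rather than the fused JIT kernels.

import torch

# Made-up shapes: batch 2, sub-sequence length 8, hidden size 16.
hidden_states = torch.randn(2, 8, 16)
attention_output = torch.randn(2, 8, 16)
attention_bias = torch.zeros(16)

residual = hidden_states  # pre-LN residual; post-LN would use the layernorm output
dropped = torch.nn.functional.dropout(
    attention_output + attention_bias.expand_as(residual), p=0.1, training=True)
layernorm_input = residual + dropped  # fed into post_attention_layernorm
print(layernorm_input.shape)          # torch.Size([2, 8, 16])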

@@ -0,0 +1,13 @@
import torch
def bias_dropout_add(x, bias, residual, prob, training):
# type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
out = residual + out
return out
def get_bias_dropout_add(training):
def _bias_dropout_add(x, bias, residual, prob):
return bias_dropout_add(x, bias, residual, prob, training)
return _bias_dropout_add
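
A quick way to see how the returned closure is used; shapes are illustrative and the dropout probability of 0.1 is arbitrary.

import torch

x = torch.randn(2, 8, 16)            # e.g. attention output
bias = torch.zeros(16).expand_as(x)
residual = torch.randn(2, 8, 16)

# get_bias_dropout_add is defined above
bias_dropout_add_train = get_bias_dropout_add(training=True)
out = bias_dropout_add_train(x, bias, residual, 0.1)  # dropout(x + bias) + residual
print(out.shape)                     # torch.Size([2, 8, 16])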

@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class VocabEmbedding(torch.nn.Module):
def __init__(self, num_embeddings, embedding_dim):
super(VocabEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
# Allocate weights and initialize.
self.weight = nn.Parameter(torch.empty(
self.num_embeddings, self.embedding_dim))
init.xavier_uniform_(self.weight)
def forward(self, hidden_state):
output = F.embedding(hidden_state, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
return output
def __repr__(self):
return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
f'embedding_dim={self.embedding_dim})'
class Embedding(nn.Module):
"""Language model embeddings.
Arguments:
hidden_size: hidden size
vocab_size: vocabulary size
        max_sequence_length: maximum sequence length; used for the
                             positional embeddings
        embedding_dropout_prob: dropout probability for the embeddings
        num_tokentypes: number of token types for the token-type
                        embeddings; a value of 0 disables this embedding
"""
def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes):
super(Embedding, self).__init__()
self.hidden_size = hidden_size
self.num_tokentypes = num_tokentypes
self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)
# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
max_sequence_length, self.hidden_size)
        # Token type embedding.
        # Optional, so that a pretrained model saved without token
        # types can be loaded and the embedding added later if needed.
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
self.hidden_size)
else:
self.tokentype_embeddings = None
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
@property
def word_embedding_weight(self):
return self.word_embeddings.weight
def forward(self, input_ids, position_ids, tokentype_ids=None):
# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None and self.tokentype_embeddings is not None:
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
# Dropout.
embeddings = self.embedding_dropout(embeddings)
return embeddings
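
A usage sketch of the Embedding module defined above; the sizes are illustrative (BERT-base-like) rather than taken from the example's config.

import torch

emb = Embedding(hidden_size=768,
                vocab_size=30522,
                max_sequence_length=512,
                embedding_dropout_prob=0.1,
                num_tokentypes=0)
input_ids = torch.randint(0, 30522, (2, 128))
position_ids = torch.arange(128, dtype=torch.long).unsqueeze(0).expand(2, 128)
out = emb(input_ids, position_ids)
print(out.shape)  # torch.Size([2, 128, 768])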

@@ -0,0 +1,78 @@
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooler import Pooler
from .linear import Linear
from .embedding import VocabEmbedding
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from loss_func.cross_entropy import vocab_cross_entropy
class BertLMHead(nn.Module):
"""Masked LM head for Bert
Arguments:
hidden_size: hidden size
init_method: init method for weight initialization
layernorm_epsilon: tolerance for layer norm divisions
"""
def __init__(self,
vocab_size,
hidden_size,
):
super(BertLMHead, self).__init__()
self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
self.dense = Linear(hidden_size, hidden_size)
self.layernorm = LayerNorm(hidden_size)
self.gelu = torch.nn.functional.gelu
def forward(self, hidden_states, word_embeddings_weight, lm_labels):
hidden_states = self.dense(hidden_states)
hidden_states = self.gelu(hidden_states)
hidden_states = self.layernorm(hidden_states)
output = F.linear(hidden_states, word_embeddings_weight, self.bias)
lm_loss = vocab_cross_entropy(output, lm_labels)
return lm_loss
class BertBinaryHead(nn.Module):
def __init__(self, hidden_size):
super().__init__()
self.pooler = Pooler(hidden_size)
self.dense = Linear(hidden_size, 2)
def forward(self, hidden_states):
if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0:
output = self.pooler(hidden_states)
output = self.dense(output)
else:
output = None
return output
class BertDualHead(nn.Module):
def __init__(self, hidden_size, vocab_size, add_binary_head):
super().__init__()
self.lm_head = BertLMHead(vocab_size, hidden_size)
self.add_binary_head = add_binary_head
if add_binary_head:
self.binary_head = BertBinaryHead(hidden_size)
else:
self.binary_head = None
def forward(self, hidden_states, word_embeddings_weight, lm_labels):
if self.add_binary_head:
binary_output = self.binary_head(hidden_states)
else:
binary_output = None
lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels)
return lm_loss, binary_output
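
The LM head ties its output projection to the token-embedding matrix: the same weight that maps ids to vectors is reused to map hidden states back to vocabulary logits. A framework-free sketch of that tying, with illustrative sizes:

import torch
import torch.nn.functional as F

vocab_size, hidden_size = 30522, 768
word_embeddings_weight = torch.randn(vocab_size, hidden_size)  # shared with the embedding layer
bias = torch.zeros(vocab_size)

hidden_states = torch.randn(2, 128, hidden_size)  # transformer output
logits = F.linear(hidden_states, word_embeddings_weight, bias)
print(logits.shape)  # torch.Size([2, 128, 30522])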

@@ -0,0 +1,12 @@
import torch
import math
def init_normal(tensor, sigma):
"""Init method based on N(0, sigma)."""
torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
def output_init_normal(tensor, sigma, num_layers):
"""Init method based on N(0, sigma/sqrt(2*num_layers)."""
std = sigma / math.sqrt(2.0 * num_layers)
torch.nn.init.normal_(tensor, mean=0.0, std=std)
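
For example, with sigma = 0.02 and a 12-layer model the scaled standard deviation is 0.02 / sqrt(24) ≈ 0.0041. The values below are illustrative, not taken from the example's config.

import torch

# init_normal / output_init_normal as defined above.
sigma, num_layers = 0.02, 12
w_in = torch.empty(768, 768)
w_out = torch.empty(768, 768)
init_normal(w_in, sigma)                      # N(0, 0.02)
output_init_normal(w_out, sigma, num_layers)  # std = 0.02 / sqrt(24) ≈ 0.0041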

@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init
class Linear(nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: if True, add a bias term. The bias is always initialized
              to zero.
        skip_bias_add: enables a performance optimization in which the
                       bias can be fused with other elementwise
                       operations; instead of being added here, the bias
                       is returned alongside the output.
    """
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
super(Linear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
init.normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output = F.linear(input_, self.weight, bias)
if self.skip_bias_add:
return output, self.bias
else:
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
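
With skip_bias_add=True the layer returns (output_without_bias, bias) so the bias can be folded into a later fused op such as bias-GeLU or bias-dropout-add. A small sketch with made-up sizes:

import torch

layer = Linear(16, 32, skip_bias_add=True)    # Linear as defined above
x = torch.randn(4, 16)
out, bias = layer(x)          # bias is returned, not added
y = out + bias                # adding it back gives the ordinary linear output
print(out.shape, bias.shape)  # torch.Size([4, 32]) torch.Size([32])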

@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl
class TransformerMLP(nn.Module):
"""MLP.
MLP will take the input with h hidden state, project it to 4*h
hidden dimension, perform nonlinear transformation, and project the
state back into h hidden dimension. At the end, dropout is also
applied.
"""
def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
super(TransformerMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = Linear(
hidden_size,
int(hidden_size*mlp_ratio),
skip_bias_add=True)
self.bias_gelu_fusion = fuse_gelu
self.activation_func = F.gelu
# Project back to h.
self.dense_4h_to_h = Linear(
int(hidden_size*mlp_ratio),
hidden_size,
skip_bias_add=True)
def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        # projected to [s, b, mlp_ratio * h], transformed,
        # and projected back to [s, b, h]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)
if self.bias_gelu_fusion:
intermediate_parallel = \
bias_gelu_impl(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = \
self.activation_func(intermediate_parallel + bias_parallel)
# [s, b, h]
output, output_bias = self.dense_4h_to_h(intermediate_parallel)
return output, output_bias
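
A usage sketch of the TransformerMLP defined above, with illustrative sizes. It assumes ColossalAI is installed (the module imports bias_gelu_impl at import time); fuse_gelu=False keeps the forward pass on the plain F.gelu path.

import torch

mlp = TransformerMLP(hidden_size=768, mlp_ratio=4, fuse_gelu=False)
x = torch.randn(128, 2, 768)        # [s, b, h]
out, out_bias = mlp(x)
print(out.shape, out_bias.shape)    # torch.Size([128, 2, 768]) torch.Size([768])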

@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from .linear import Linear
class Pooler(nn.Module):
"""Pooler layer.
    Pool the hidden state of a specific token (for example, the start of
    the sequence) and apply a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
    """
def __init__(self, hidden_size):
super(Pooler, self).__init__()
self.dense = Linear(hidden_size, hidden_size)
def forward(self, hidden_states, sequence_index=0):
# hidden_states: [b, s, h]
# sequence_index: index of the token to pool.
pooled = hidden_states[:, sequence_index, :]
pooled = self.dense(pooled)
pooled = torch.tanh(pooled)
return pooled
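
A usage sketch of the Pooler defined above; sizes are illustrative.

import torch

pooler = Pooler(hidden_size=768)            # Pooler as defined above
hidden_states = torch.randn(2, 128, 768)    # [b, s, h]
pooled = pooler(hidden_states)              # pools the token at index 0 by default
print(pooled.shape)                         # torch.Size([2, 768])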

@@ -0,0 +1,58 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc
class PreProcessor(nn.Module):
def __init__(self, sub_seq_length):
super().__init__()
self.sub_seq_length = sub_seq_length
def bert_position_ids(self, token_ids):
# Create position ids
seq_length = token_ids.size(1)
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
position_ids = torch.arange(seq_length*local_rank,
seq_length * (local_rank+1),
dtype=torch.long,
device=token_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(token_ids)
return position_ids
def bert_extended_attention_mask(self, attention_mask):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
start_index = local_rank * self.sub_seq_length
end_index = (local_rank + 1) * self.sub_seq_length
# We create a 3D attention mask from a 2D tensor mask.
# [b, 1, s]
attention_mask_b1s = attention_mask.unsqueeze(1)
# [b, s, 1]
attention_mask_bs1 = attention_mask.unsqueeze(2)
        # [b, s, s]
        attention_mask_bss = attention_mask_b1s * attention_mask_bs1
        # [b, s/D, s]
        attention_mask_bss = attention_mask_bss[:, start_index:end_index, :]
# [b, 1, s/D, s]
extended_attention_mask = attention_mask_bss.unsqueeze(1)
        # Convert the attention mask to boolean: True marks positions to mask out.
        extended_attention_mask = (extended_attention_mask < 0.5)
return extended_attention_mask
def forward(self, input_ids=None, attention_mask=None):
if attention_mask is not None:
extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
else:
extended_attention_mask = None
if input_ids is not None:
position_ids = self.bert_position_ids(input_ids)
else:
position_ids = None
return position_ids, extended_attention_mask
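
The extended-mask construction can be checked without a distributed setup by fixing local_rank to 0 and replicating the tensor algebra; the sizes and padding pattern below are made up. True entries in the result mark positions the attention kernel should mask out.

import torch

batch, seq_len, world_size = 2, 8, 2
sub_seq_length = seq_len // world_size
local_rank = 0                                      # pretend this is the first sequence-parallel rank

attention_mask = torch.ones(batch, seq_len)         # 1 = real token, 0 = padding
attention_mask[:, -2:] = 0                          # pad the last two positions

start = local_rank * sub_seq_length
end = (local_rank + 1) * sub_seq_length
mask_b1s = attention_mask.unsqueeze(1)              # [b, 1, s]
mask_bs1 = attention_mask.unsqueeze(2)              # [b, s, 1]
mask_bss = (mask_b1s * mask_bs1)[:, start:end, :]   # [b, s/D, s]
extended = (mask_bss.unsqueeze(1) < 0.5)            # [b, 1, s/D, s]; True = masked out
print(extended.shape)                               # torch.Size([2, 1, 4, 8])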