[tutorial] edited hands-on practices (#1899)
* Add hands-on examples to ColossalAI.
* Change names of the hands-on examples and edit the sequence parallel example.
* Edit wrong folder name.
* Resolve conflict.
* Delete readme.
examples/tutorial/sequence_parallel/model/__init__.py (new file, 2 lines)
examples/tutorial/sequence_parallel/model/bert.py (new file, 282 lines)
@@ -0,0 +1,282 @@
import torch
import torch.nn as nn
import inspect

from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.utils import partition_uniform


class BertForPretrain(nn.Module):

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, \
            'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        if not add_binary_head:
            num_tokentypes = 0

        self.preprocessor = PreProcessor(self.sub_seq_length)
        self.embedding = Embedding(hidden_size=hidden_size,
                                   vocab_size=vocab_size,
                                   max_sequence_length=max_sequence_length,
                                   embedding_dropout_prob=dropout_prob,
                                   num_tokentypes=num_tokentypes)
        self.bert_layers = nn.ModuleList()

        for i in range(num_layers):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16)
            self.bert_layers.append(bert_layer)

        self.layer_norm = LayerNorm(hidden_size)
        self.head = BertDualHead(hidden_size,
                                 self.embedding.word_embedding_weight.size(0),
                                 add_binary_head=add_binary_head)
        self.reset_parameters()

    def _init_normal(self, tensor):
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        # initialize embedding
        self._init_normal(self.embedding.word_embedding_weight)
        self._init_normal(self.embedding.position_embeddings.weight)
        if self.embedding.tokentype_embeddings:
            self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize bert layers
        for layer in self.bert_layers:
            # initialize self attention
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize head
        self._init_normal(self.head.lm_head.dense.weight)
        if self.head.binary_head is not None:
            self._init_normal(self.head.binary_head.pooler.dense.weight)
            self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        # inputs of the forward function
        # input_ids: [batch_size, sub_seq_len]
        # attention_masks: [batch_size, seq_len]
        # tokentype_ids: [batch_size, sub_seq_len]
        # outputs of the preprocessor
        # pos_ids: [batch_size, sub_seq_len]
        # attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)

        hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)

        # hidden_states shape change:
        # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        for idx, layer in enumerate(self.bert_layers):
            hidden_states = layer(hidden_states, attention_masks)

        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.layer_norm(hidden_states)

        # output: [batch_size, sub_seq_len, hidden_size]
        # word_embedding: [vocab_size, hidden_size]
        return self.head(output, self.embedding.word_embedding_weight, lm_labels)


class PipelineBertForPretrain(nn.Module):

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False,
                 first_stage=True,
                 last_stage=True,
                 start_idx=None,
                 end_idx=None):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, \
            'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        if not add_binary_head:
            num_tokentypes = 0

        self.first_stage = first_stage
        self.last_stage = last_stage

        self.preprocessor = PreProcessor(self.sub_seq_length)

        if self.first_stage:
            self.embedding = Embedding(hidden_size=hidden_size,
                                       vocab_size=vocab_size,
                                       max_sequence_length=max_sequence_length,
                                       embedding_dropout_prob=dropout_prob,
                                       num_tokentypes=num_tokentypes)

        # transformer layers
        self.bert_layers = nn.ModuleList()

        if start_idx is None and end_idx is None:
            start_idx = 0
            end_idx = num_layers

        for i in range(start_idx, end_idx):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16)
            self.bert_layers.append(bert_layer)

        if self.last_stage:
            self.word_embeddings = VocabEmbedding(vocab_size, hidden_size)
            self.layer_norm = LayerNorm(hidden_size)
            self.head = BertDualHead(hidden_size, vocab_size,
                                     add_binary_head=add_binary_head)
        self.reset_parameters()

    def _init_normal(self, tensor):
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        # initialize embedding
        if self.first_stage:
            self._init_normal(self.embedding.word_embedding_weight)
            self._init_normal(self.embedding.position_embeddings.weight)
            if self.embedding.tokentype_embeddings:
                self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize bert layers
        for layer in self.bert_layers:
            # initialize self attention
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize head
        if self.last_stage:
            self._init_normal(self.head.lm_head.dense.weight)
            if self.head.binary_head is not None:
                self._init_normal(self.head.binary_head.pooler.dense.weight)
                self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        # inputs of the forward function
        # input_ids: [batch_size, sub_seq_len]
        # attention_masks: [batch_size, seq_len]
        # tokentype_ids: [batch_size, sub_seq_len]
        # outputs of the preprocessor
        # pos_ids: [batch_size, sub_seq_len]
        # attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        if self.first_stage:
            pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
        else:
            _, attention_masks = self.preprocessor(None, attention_masks)

        if self.first_stage:
            # hidden_states shape change:
            # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
            hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            hidden_states = input_ids

        for idx, layer in enumerate(self.bert_layers):
            hidden_states = layer(hidden_states, attention_masks)

        if self.last_stage:
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.layer_norm(hidden_states)
            # output: [batch_size, sub_seq_len, hidden_size]
            # word_embeddings: [vocab_size, hidden_size]
            output = self.head(output, self.word_embeddings.weight, lm_labels)
        else:
            output = hidden_states

        return output


def _filter_kwargs(func, kwargs):
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    rank = gpc.get_global_rank()
    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []
    for start, end in parts:
        kwargs['num_layers'] = num_layers
        kwargs['start_idx'] = start
        kwargs['end_idx'] = end
        kwargs['first_stage'] = start == 0
        kwargs['last_stage'] = end == num_layers
        logger.info(f'Rank{rank} build layer {start}-{end}, {end - start}/{num_layers} layers')
        chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
        if start == 0:
            wrapper.register_module(chunk.embedding.word_embeddings)
        elif end == num_layers:
            wrapper.register_module(chunk.word_embeddings)
        models.append(chunk)
    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)
    return model
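The shape bookkeeping above is the crux of the sequence parallel example: each rank holds only max_sequence_length // seq_parallel_size tokens, and hidden states are transposed to [sub_seq_len, batch_size, hidden_size] before entering the ring self-attention layers. A minimal sketch of that arithmetic with plain tensors (the sizes below are made-up illustration values, not taken from the commit):

import torch

# Hypothetical sizes for illustration only.
batch_size, max_sequence_length, hidden_size = 4, 512, 768
seq_parallel_size = 4                                      # ranks in the SEQUENCE group
assert max_sequence_length % seq_parallel_size == 0
sub_seq_len = max_sequence_length // seq_parallel_size     # 128 tokens per rank

# Each rank embeds only its own sub-sequence: [batch, sub_seq_len, hidden]
hidden_states = torch.randn(batch_size, sub_seq_len, hidden_size)

# BertForPretrain transposes to [sub_seq_len, batch, hidden] before the layer stack ...
hidden_states = hidden_states.transpose(0, 1).contiguous()
print(hidden_states.shape)    # torch.Size([128, 4, 768])

# ... and transposes back before the final LayerNorm and the dual head.
hidden_states = hidden_states.transpose(0, 1).contiguous()
print(hidden_states.shape)    # torch.Size([4, 128, 768])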
examples/tutorial/sequence_parallel/model/layers/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .embedding import VocabEmbedding, Embedding
from .bert_layer import BertLayer
from .head import BertDualHead
from .preprocess import PreProcessor
examples/tutorial/sequence_parallel/model/layers/bert_layer.py (new file, 118 lines)
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
from colossalai.nn.layer.parallel_sequence import TransformerSelfAttentionRing
from colossalai.kernel.jit import bias_dropout_add_fused_train, bias_dropout_add_fused_inference
from colossalai.kernel.cuda_native import LayerNorm
from .mlp import TransformerMLP
from .dropout import get_bias_dropout_add


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


class BertLayer(nn.Module):
    """A single transformer layer.

    The transformer layer takes input of size [sub_seq_len, batch_size, hidden_size]
    and returns an output of the same size.
    """

    def __init__(self,
                 layer_number,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout,
                 mlp_ratio,
                 hidden_dropout,
                 is_naive_fp16,
                 apply_residual_connection_post_layernorm=False,
                 fp32_residual_connection=False,
                 bias_dropout_fusion: bool = True,
                 convert_fp16_to_fp32_in_softmax: bool = False):
        super().__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.fp32_residual_connection = fp32_residual_connection

        # Layernorm on the input data.
        self.input_layernorm = LayerNorm(hidden_size)

        # Self attention.
        self.self_attention = TransformerSelfAttentionRing(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            attention_dropout=attention_dropout,
            attention_mask_func=attention_mask_func,
            layer_number=layer_number,
            apply_query_key_layer_scaling=True,
            convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
            fp16=is_naive_fp16)

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(hidden_size)

        self.mlp = TransformerMLP(hidden_size=hidden_size, mlp_ratio=mlp_ratio)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [sub_seq_len, batch_size, hidden_size]
        # attention_mask: [batch_size, 1, sub_seq_len, seq_len]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(layernorm_output, attention_mask)

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # jit scripting for a nn.Module (with dropout) is not
        # triggering the fusion kernel. For now, we use two
        # different nn.functional routines to account for varying
        # dropout semantics during training and inference phases.
        if self.bias_dropout_fusion:
            if self.training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(self.training)

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            layernorm_input = bias_dropout_add_func(
                attention_output,
                attention_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        # Layer norm after the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        # Re-enable torch grad to enable fused optimization.
        with torch.enable_grad():
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias.expand_as(residual),
                residual,
                self.hidden_dropout)

        return output
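The fused bias_dropout_add_fused_* kernels and the unfused fallback in dropout.py compute the same thing: add the bias that the linear layer skipped, apply dropout, then add the residual. A self-contained restatement in plain PyTorch (shapes are illustrative):

import torch
import torch.nn.functional as F

def bias_dropout_add_reference(x, bias, residual, prob, training):
    # What the fused kernel computes: dropout(x + bias) + residual.
    return residual + F.dropout(x + bias, p=prob, training=training)

sub_seq_len, batch, hidden = 8, 2, 16
attention_output = torch.randn(sub_seq_len, batch, hidden)
attention_bias = torch.randn(hidden)              # broadcast like bias.expand_as(residual)
residual = torch.randn(sub_seq_len, batch, hidden)

out = bias_dropout_add_reference(attention_output,
                                 attention_bias.expand_as(residual),
                                 residual,
                                 prob=0.1,
                                 training=True)
print(out.shape)    # torch.Size([8, 2, 16])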
examples/tutorial/sequence_parallel/model/layers/dropout.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import torch


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor
    out = torch.nn.functional.dropout(x + bias, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add
examples/tutorial/sequence_parallel/model/layers/embedding.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class VocabEmbedding(torch.nn.Module):

    def __init__(self, num_embeddings, embedding_dim):
        super(VocabEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None

        # Allocate weights and initialize.
        self.weight = nn.Parameter(torch.empty(
            self.num_embeddings, self.embedding_dim))
        init.xavier_uniform_(self.weight)

    def forward(self, hidden_state):
        output = F.embedding(hidden_state, self.weight,
                             self.padding_idx, self.max_norm,
                             self.norm_type, self.scale_grad_by_freq,
                             self.sparse)
        return output

    def __repr__(self):
        return f'VocabEmbedding(num_embeddings={self.num_embeddings}, ' \
               f'embedding_dim={self.embedding_dim})'


class Embedding(nn.Module):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        max_sequence_length: maximum length of a sequence. This
            is used for the positional embedding.
        embedding_dropout_prob: dropout probability for the embeddings
        num_tokentypes: size of the token-type embeddings. A value of 0
            disables this embedding.
    """

    def __init__(self,
                 hidden_size,
                 vocab_size,
                 max_sequence_length,
                 embedding_dropout_prob,
                 num_tokentypes):
        super(Embedding, self).__init__()

        self.hidden_size = hidden_size
        self.num_tokentypes = num_tokentypes

        self.word_embeddings = VocabEmbedding(vocab_size, self.hidden_size)

        # Position embedding (serial).
        self.position_embeddings = torch.nn.Embedding(
            max_sequence_length, self.hidden_size)

        # Token type embedding.
        # Add this as an optional field that can be added through
        # a method call so we can load a pretrained model without
        # token types and add them as needed.
        if self.num_tokentypes > 0:
            self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes,
                                                           self.hidden_size)
        else:
            self.tokentype_embeddings = None

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

    @property
    def word_embedding_weight(self):
        return self.word_embeddings.weight

    def forward(self, input_ids, position_ids, tokentype_ids=None):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        embeddings = words_embeddings + position_embeddings
        if tokentype_ids is not None and self.tokentype_embeddings is not None:
            embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)

        # Dropout.
        embeddings = self.embedding_dropout(embeddings)

        return embeddings
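The Embedding wrapper is just the sum of word, position, and optional token-type embeddings followed by dropout. A minimal stand-in with stock torch.nn.Embedding modules (sizes are illustrative, not from the commit):

import torch
import torch.nn as nn

vocab_size, max_seq_len, num_tokentypes, hidden = 1000, 128, 2, 32
word_emb = nn.Embedding(vocab_size, hidden)
pos_emb = nn.Embedding(max_seq_len, hidden)
type_emb = nn.Embedding(num_tokentypes, hidden)
dropout = nn.Dropout(0.1)

input_ids = torch.randint(0, vocab_size, (4, 16))                   # [batch, sub_seq_len]
position_ids = torch.arange(16).unsqueeze(0).expand_as(input_ids)   # this rank's positions
tokentype_ids = torch.zeros_like(input_ids)

embeddings = dropout(word_emb(input_ids) + pos_emb(position_ids) + type_emb(tokentype_ids))
print(embeddings.shape)    # torch.Size([4, 16, 32])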
examples/tutorial/sequence_parallel/model/layers/head.py (new file, 78 lines)
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .pooler import Pooler
from .linear import Linear
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from loss_func.cross_entropy import vocab_cross_entropy


class BertLMHead(nn.Module):
    """Masked LM head for BERT.

    Arguments:
        vocab_size: vocabulary size
        hidden_size: hidden size
    """

    def __init__(self,
                 vocab_size,
                 hidden_size):
        super(BertLMHead, self).__init__()
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))

        self.dense = Linear(hidden_size, hidden_size)
        self.layernorm = LayerNorm(hidden_size)
        self.gelu = torch.nn.functional.gelu

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)
        hidden_states = self.layernorm(hidden_states)

        output = F.linear(hidden_states, word_embeddings_weight, self.bias)
        lm_loss = vocab_cross_entropy(output, lm_labels)

        return lm_loss


class BertBinaryHead(nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.pooler = Pooler(hidden_size)
        self.dense = Linear(hidden_size, 2)

    def forward(self, hidden_states):
        if gpc.get_local_rank(ParallelMode.SEQUENCE) == 0:
            output = self.pooler(hidden_states)
            output = self.dense(output)
        else:
            output = None
        return output


class BertDualHead(nn.Module):

    def __init__(self, hidden_size, vocab_size, add_binary_head):
        super().__init__()
        self.lm_head = BertLMHead(vocab_size, hidden_size)
        self.add_binary_head = add_binary_head
        if add_binary_head:
            self.binary_head = BertBinaryHead(hidden_size)
        else:
            self.binary_head = None

    def forward(self, hidden_states, word_embeddings_weight, lm_labels):
        if self.add_binary_head:
            binary_output = self.binary_head(hidden_states)
        else:
            binary_output = None
        lm_loss = self.lm_head(hidden_states, word_embeddings_weight, lm_labels)
        return lm_loss, binary_output
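BertLMHead ties the output projection to the input word embedding: logits are computed with F.linear(hidden, word_embeddings_weight, bias), so no separate vocab_size x hidden_size output matrix is allocated. A small sketch of that weight tying (the loss step is omitted; vocab_cross_entropy lives in the example's loss_func package):

import torch
import torch.nn.functional as F

vocab_size, hidden = 1000, 32
word_embedding_weight = torch.randn(vocab_size, hidden)   # shared with the input embedding
output_bias = torch.zeros(vocab_size)

hidden_states = torch.randn(4, 16, hidden)                # [batch, sub_seq_len, hidden]
logits = F.linear(hidden_states, word_embedding_weight, output_bias)
print(logits.shape)    # torch.Size([4, 16, 1000]) -- one score per vocabulary entry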
examples/tutorial/sequence_parallel/model/layers/init_method.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import torch
import math


def init_normal(tensor, sigma):
    """Init method based on N(0, sigma)."""
    torch.nn.init.normal_(tensor, mean=0.0, std=sigma)


def output_init_normal(tensor, sigma, num_layers):
    """Init method based on N(0, sigma / sqrt(2 * num_layers))."""
    std = sigma / math.sqrt(2.0 * num_layers)
    torch.nn.init.normal_(tensor, mean=0.0, std=std)
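The second initializer shrinks the standard deviation by sqrt(2 * num_layers), the Megatron-style scaling applied above to the output projections of the attention and MLP blocks so that residual additions do not grow the activation variance with depth. A quick numeric check of the scaling (the values are illustrative):

import math
import torch

sigma, num_layers = 0.02, 24
scaled_std = sigma / math.sqrt(2.0 * num_layers)
print(round(scaled_std, 5))    # 0.00289 for these values

w = torch.empty(4096, 1024)
torch.nn.init.normal_(w, mean=0.0, std=scaled_std)
print(float(w.std()))          # empirical std is close to scaled_std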
examples/tutorial/sequence_parallel/model/layers/linear.py (new file, 63 lines)
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
import torch.nn.init as init


class Linear(nn.Module):
    """Linear layer.

    The linear layer is defined as Y = XA + b.

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: if true, add a bias (always initialized to zero).
        skip_bias_add: this was added to enable performance optimizations where the
                       bias can be fused with other elementwise operations. We skip
                       adding the bias here and return it instead.
    """

    def __init__(self,
                 input_size,
                 output_size,
                 bias=True,
                 skip_bias_add=False):
        super(Linear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add

        self.weight = Parameter(torch.empty(self.output_size,
                                            self.input_size))
        init.normal_(self.weight)
        if bias:
            self.bias = Parameter(torch.empty(self.output_size))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        # Matrix multiply.
        bias = self.bias if not self.skip_bias_add else None
        output = F.linear(input_, self.weight, bias)

        if self.skip_bias_add:
            return output, self.bias
        else:
            return output

    def __repr__(self):
        return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
               f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
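skip_bias_add=True is what lets the caller fuse the bias with a later elementwise op (bias-GeLU in the MLP, bias-dropout-add around the residuals): the layer returns the raw matmul and the bias separately instead of adding the bias itself. A plain-PyTorch sketch of the two call styles (sizes are illustrative):

import torch
import torch.nn.functional as F

hidden = 16
weight = torch.randn(hidden, hidden)
bias = torch.randn(hidden)
x = torch.randn(4, hidden)

# skip_bias_add=False: the bias is folded into the matmul result.
y_fused_now = F.linear(x, weight, bias)

# skip_bias_add=True: return the raw matmul and hand the bias to the caller,
# which can fuse it with gelu / dropout / residual-add later.
y_raw = F.linear(x, weight, None)
y_fused_later = y_raw + bias

print(torch.allclose(y_fused_now, y_fused_later))    # True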
examples/tutorial/sequence_parallel/model/layers/mlp.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .linear import Linear
from colossalai.kernel.jit import bias_gelu_impl


class TransformerMLP(nn.Module):
    """MLP.

    The MLP takes the input with h hidden states, projects it to mlp_ratio * h
    hidden dimensions, performs a nonlinear transformation, and projects the
    state back into h hidden dimensions. The bias of the second projection is
    returned separately so the caller can fuse it with dropout and the residual add.
    """

    def __init__(self, hidden_size, mlp_ratio, fuse_gelu=True):
        super(TransformerMLP, self).__init__()

        # Project to 4h.
        self.dense_h_to_4h = Linear(
            hidden_size,
            int(hidden_size * mlp_ratio),
            skip_bias_add=True)

        self.bias_gelu_fusion = fuse_gelu
        self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = Linear(
            int(hidden_size * mlp_ratio),
            hidden_size,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # hidden states should be in the shape of [s, b, h];
        # they are projected to [s, b, 4h] and then back to [s, b, h]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate_parallel = \
                bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            intermediate_parallel = \
                self.activation_func(intermediate_parallel + bias_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
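With the default mlp_ratio=4, the block expands the hidden dimension by 4x, applies GeLU, and projects back; the final bias is returned so the caller can fuse it with dropout and the residual. A minimal unfused equivalent (illustrative sizes):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, mlp_ratio = 32, 4
dense_h_to_4h = nn.Linear(hidden, hidden * mlp_ratio)
dense_4h_to_h = nn.Linear(hidden * mlp_ratio, hidden)

x = torch.randn(8, 2, hidden)                    # [s, b, h]
intermediate = F.gelu(dense_h_to_4h(x))          # [s, b, 4h]
output = dense_4h_to_h(intermediate)             # [s, b, h]
print(intermediate.shape, output.shape)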
examples/tutorial/sequence_parallel/model/layers/pooler.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import torch
import torch.nn as nn
from .linear import Linear


class Pooler(nn.Module):
    """Pooler layer.

    Pool the hidden states of a specific token (for example the start of the
    sequence) and apply a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
    """

    def __init__(self, hidden_size):
        super(Pooler, self).__init__()
        self.dense = Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
        # sequence_index: index of the token to pool.
        pooled = hidden_states[:, sequence_index, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
        return pooled
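The pooler simply selects one token (index 0, the [CLS] position in BERT) and passes it through a dense layer with tanh; in this example only sequence-parallel rank 0 runs it, since that rank holds the first sub-sequence. A compact plain-PyTorch equivalent (illustrative sizes):

import torch
import torch.nn as nn

hidden = 32
dense = nn.Linear(hidden, hidden)
hidden_states = torch.randn(4, 16, hidden)       # [b, s, h]

pooled = torch.tanh(dense(hidden_states[:, 0, :]))
print(pooled.shape)                              # torch.Size([4, 32])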
examples/tutorial/sequence_parallel/model/layers/preprocess.py (new file, 58 lines)
@@ -0,0 +1,58 @@
from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
from colossalai.core import global_context as gpc


class PreProcessor(nn.Module):

    def __init__(self, sub_seq_length):
        super().__init__()
        self.sub_seq_length = sub_seq_length

    def bert_position_ids(self, token_ids):
        # Create position ids
        seq_length = token_ids.size(1)
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        position_ids = torch.arange(seq_length * local_rank,
                                    seq_length * (local_rank + 1),
                                    dtype=torch.long,
                                    device=token_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

        return position_ids

    def bert_extended_attention_mask(self, attention_mask):
        local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
        start_index = local_rank * self.sub_seq_length
        end_index = (local_rank + 1) * self.sub_seq_length

        # We create a 3D attention mask from a 2D tensor mask.
        # [b, 1, s]
        attention_mask_b1s = attention_mask.unsqueeze(1)
        # [b, s, 1]
        attention_mask_bs1 = attention_mask.unsqueeze(2)
        # [b, s, s]
        attention_mask_bss = attention_mask_b1s * attention_mask_bs1

        # keep only this rank's sub-sequence rows: [b, s/D, s]
        attention_mask_bss = attention_mask_bss[:, start_index:end_index, :]

        # [b, 1, s/D, s]
        extended_attention_mask = attention_mask_bss.unsqueeze(1)

        # Convert the attention mask to binary: True marks positions to be masked out.
        extended_attention_mask = (extended_attention_mask < 0.5)

        return extended_attention_mask

    def forward(self, input_ids=None, attention_mask=None):
        if attention_mask is not None:
            extended_attention_mask = self.bert_extended_attention_mask(attention_mask)
        else:
            extended_attention_mask = None

        if input_ids is not None:
            position_ids = self.bert_position_ids(input_ids)
        else:
            position_ids = None
        return position_ids, extended_attention_mask
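The mask construction is the part that differs from the usual BERT preprocessing: the full [b, s, s] pairwise mask is built, then only the rows belonging to this rank's sub-sequence are kept, giving [b, 1, s/D, s] where D is the sequence parallel size; True entries are the positions that attention_mask_func fills with -10000. A standalone sketch for one hypothetical rank (all sizes are illustrative):

import torch

batch, seq_len, sub_seq_len, local_rank = 2, 8, 4, 1       # D = 2 in this example
attention_mask = torch.ones(batch, seq_len)
attention_mask[0, 6:] = 0                                  # pretend the last tokens are padding

start, end = local_rank * sub_seq_len, (local_rank + 1) * sub_seq_len
mask_bss = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)   # [b, s, s]
mask_sub = mask_bss[:, start:end, :]                                   # [b, s/D, s]
extended = (mask_sub.unsqueeze(1) < 0.5)                               # [b, 1, s/D, s], True = masked
print(extended.shape)    # torch.Size([2, 1, 4, 8])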