Mirror of https://github.com/hpcaitech/ColossalAI.git
[tutorial] edited hands-on practices (#1899)
* Add hands-on practices to ColossalAI.
* Change names of hands-on practices and edit the sequence parallel example.
* Edit wrong folder name.
* Resolve conflict.
* Delete README.
examples/tutorial/sequence_parallel/model/bert.py (new file, 282 lines added)
import inspect

import torch
import torch.nn as nn

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.kernel import LayerNorm
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.pipeline.utils import partition_uniform

from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal


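# This file defines the BERT model for the sequence parallelism tutorial:
#   - BertForPretrain: the non-pipelined model.
#   - PipelineBertForPretrain: a single pipeline stage holding a slice of the transformer
#     layers; stages are created by build_pipeline_bert() at the bottom of this file.
# Both variants shard the sequence dimension across the SEQUENCE parallel group, so each
# rank only processes a sub-sequence of length max_sequence_length // seq_parallel_size.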
class BertForPretrain(nn.Module):

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, \
            'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        if not add_binary_head:
            num_tokentypes = 0

        self.preprocessor = PreProcessor(self.sub_seq_length)
        self.embedding = Embedding(hidden_size=hidden_size,
                                   vocab_size=vocab_size,
                                   max_sequence_length=max_sequence_length,
                                   embedding_dropout_prob=dropout_prob,
                                   num_tokentypes=num_tokentypes)
        self.bert_layers = nn.ModuleList()

        for i in range(num_layers):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16)
            self.bert_layers.append(bert_layer)

        self.layer_norm = LayerNorm(hidden_size)
        self.head = BertDualHead(hidden_size,
                                 self.embedding.word_embedding_weight.size(0),
                                 add_binary_head=add_binary_head)
        self.reset_parameters()

    def _init_normal(self, tensor):
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        # initialize the embeddings
        self._init_normal(self.embedding.word_embedding_weight)
        self._init_normal(self.embedding.position_embeddings.weight)
        if self.embedding.tokentype_embeddings:
            self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize the transformer layers
        for layer in self.bert_layers:
            # self attention
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            # mlp
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize the heads
        self._init_normal(self.head.lm_head.dense.weight)
        if self.head.binary_head is not None:
            self._init_normal(self.head.binary_head.pooler.dense.weight)
            self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        # inputs of the forward function
        #   input_ids: [batch_size, sub_seq_len]
        #   attention_masks: [batch_size, seq_len]
        #   tokentype_ids: [batch_size, sub_seq_len]
        # outputs of the preprocessor
        #   pos_ids: [batch_size, sub_seq_len]
        #   attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)

        hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)

        # hidden_states shape change:
        # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
        hidden_states = hidden_states.transpose(0, 1).contiguous()

        for layer in self.bert_layers:
            hidden_states = layer(hidden_states, attention_masks)

        # back to [batch_size, sub_seq_len, hidden_size] before the final layer norm
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.layer_norm(hidden_states)

        # output: [batch_size, sub_seq_len, hidden_size]
        # word_embedding: [vocab_size, hidden_size]
        return self.head(output, self.embedding.word_embedding_weight, lm_labels)


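# NOTE (added illustration, not part of the original example): with sequence parallelism the
# data pipeline is expected to hand each rank only its slice of the sequence. A minimal
# sketch of how a full-length batch could be split, assuming tensors of shape
# [batch_size, seq_len] and a hypothetical `seq_parallel_rank` variable:
#
#     sub_seq_len = seq_len // gpc.get_world_size(ParallelMode.SEQUENCE)
#     start = seq_parallel_rank * sub_seq_len
#     local_input_ids = input_ids[:, start:start + sub_seq_len]
#
# The attention mask, in contrast, is passed at full length ([batch_size, seq_len]) and is
# expanded by PreProcessor to [batch_size, 1, sub_seq_len, seq_len].

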
class PipelineBertForPretrain(nn.Module):

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False,
                 first_stage=True,
                 last_stage=True,
                 start_idx=None,
                 end_idx=None):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, \
            'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        if not add_binary_head:
            num_tokentypes = 0

        self.first_stage = first_stage
        self.last_stage = last_stage

        self.preprocessor = PreProcessor(self.sub_seq_length)

        # the input embedding lives on the first pipeline stage only
        if self.first_stage:
            self.embedding = Embedding(hidden_size=hidden_size,
                                       vocab_size=vocab_size,
                                       max_sequence_length=max_sequence_length,
                                       embedding_dropout_prob=dropout_prob,
                                       num_tokentypes=num_tokentypes)

        # transformer layers assigned to this stage
        self.bert_layers = nn.ModuleList()

        if start_idx is None and end_idx is None:
            start_idx = 0
            end_idx = num_layers

        for i in range(start_idx, end_idx):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16)
            self.bert_layers.append(bert_layer)

        # the output word embedding, layer norm and heads live on the last stage only
        if self.last_stage:
            self.word_embeddings = VocabEmbedding(vocab_size, hidden_size)
            self.layer_norm = LayerNorm(hidden_size)
            self.head = BertDualHead(hidden_size, vocab_size,
                                     add_binary_head=add_binary_head)

        self.reset_parameters()

    def _init_normal(self, tensor):
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        # initialize the embeddings
        if self.first_stage:
            self._init_normal(self.embedding.word_embedding_weight)
            self._init_normal(self.embedding.position_embeddings.weight)
            if self.embedding.tokentype_embeddings:
                self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize the transformer layers of this stage
        for layer in self.bert_layers:
            # self attention
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            # mlp
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize the heads
        if self.last_stage:
            self._init_normal(self.head.lm_head.dense.weight)
            if self.head.binary_head is not None:
                self._init_normal(self.head.binary_head.pooler.dense.weight)
                self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        # inputs of the forward function
        #   input_ids: [batch_size, sub_seq_len] on the first stage,
        #              hidden states of shape [sub_seq_len, batch_size, hidden_size] otherwise
        #   attention_masks: [batch_size, seq_len]
        #   tokentype_ids: [batch_size, sub_seq_len]
        # outputs of the preprocessor
        #   pos_ids: [batch_size, sub_seq_len]
        #   attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        if self.first_stage:
            pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
        else:
            _, attention_masks = self.preprocessor(None, attention_masks)

        if self.first_stage:
            hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
            # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            hidden_states = input_ids

        # hidden_states: [sub_seq_len, batch_size, hidden_size]
        for layer in self.bert_layers:
            hidden_states = layer(hidden_states, attention_masks)

        if self.last_stage:
            # back to [batch_size, sub_seq_len, hidden_size] before the final layer norm
            hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.layer_norm(hidden_states)
            # output: [batch_size, sub_seq_len, hidden_size]
            # word_embedding: [vocab_size, hidden_size]
            output = self.head(output, self.word_embeddings.weight, lm_labels)
        else:
            output = hidden_states

        return output


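# NOTE (added commentary): in the pipelined model only the first stage receives real token
# ids. For every later stage, the pipeline schedule passes the previous stage's output
# (hidden states of shape [sub_seq_len, batch_size, hidden_size]) in place of `input_ids`,
# which is why forward() simply aliases it when first_stage is False. Likewise, only the
# last stage owns the layer norm and the heads; intermediate stages return raw hidden states.

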
def _filter_kwargs(func, kwargs):
    # keep only the keyword arguments that `func` actually accepts
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


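# Example for _filter_kwargs (illustration only): if kwargs were
# {'hidden_size': 1024, 'num_chunks': 2}, then
# _filter_kwargs(PipelineBertForPretrain.__init__, kwargs) would keep only 'hidden_size',
# because 'num_chunks' is not a parameter of the constructor. This lets build_pipeline_bert()
# below forward a single kwargs dict without worrying about extra keys.

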
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    logger = get_dist_logger()
    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    rank = gpc.get_global_rank()

    # register the word embeddings of the first and last stages as shared modules
    # so that their weights stay synchronized across the two stages
    wrapper = PipelineSharedModuleWrapper([0, pipeline_size - 1])

    # assign a contiguous range of layers to each pipeline stage (and chunk)
    parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
    models = []
    for start, end in parts:
        kwargs['num_layers'] = num_layers
        kwargs['start_idx'] = start
        kwargs['end_idx'] = end
        kwargs['first_stage'] = start == 0
        kwargs['last_stage'] = end == num_layers
        logger.info(f'Rank{rank} build layer {start}-{end}, {end - start}/{num_layers} layers')
        chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)

        if start == 0:
            wrapper.register_module(chunk.embedding.word_embeddings)
        elif end == num_layers:
            wrapper.register_module(chunk.word_embeddings)
        models.append(chunk)

    if len(models) == 1:
        model = models[0]
    else:
        model = nn.ModuleList(models)
    return model
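# NOTE (added usage sketch, illustrative only): build_pipeline_bert() is meant to be called
# from the tutorial's training script after ColossalAI has been launched with pipeline and
# sequence parallelism configured. The hyper-parameter values below are placeholders, not
# the tutorial's actual settings:
#
#     model = build_pipeline_bert(num_layers=24,
#                                 num_chunks=1,
#                                 vocab_size=30522,
#                                 hidden_size=1024,
#                                 num_attention_heads=16,
#                                 max_sequence_length=512,
#                                 add_binary_head=True,
#                                 is_naive_fp16=False)
#
# Each rank builds only its own stage(s); weight sharing between the first and last stage's
# word embeddings is handled by PipelineSharedModuleWrapper above.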