Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-05 00:56:17 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
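The diff below is mechanical reformatting produced by the updated hooks (black-style wrapping, trailing commas, double quotes); no model logic changes. As a rough illustration of the pattern, assuming a hypothetical make_embedding helper (not a ColossalAI API), argument lists aligned to the opening parenthesis are rewritten into a hanging indent with a trailing comma:

# Illustration only: make_embedding is a hypothetical helper, not part of ColossalAI.
def make_embedding(hidden_size, vocab_size, embedding_dropout_prob):
    return {"hidden_size": hidden_size, "vocab_size": vocab_size, "dropout": embedding_dropout_prob}

# Before (yapf-style): arguments aligned under the opening parenthesis, single quotes.
# embedding = make_embedding(hidden_size=768,
#                            vocab_size=30522,
#                            embedding_dropout_prob=0.1)

# After (black-style): hanging indent, one argument per line, trailing comma, double quotes.
embedding = make_embedding(
    hidden_size=768,
    vocab_size=30522,
    embedding_dropout_prob=0.1,
)
print(embedding["hidden_size"])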
@@ -16,7 +16,6 @@ from .layers.init_method import init_normal, output_init_normal
 
 
 class BertForPretrain(nn.Module):
-
     def __init__(
         self,
         vocab_size,
@@ -34,7 +33,9 @@ class BertForPretrain(nn.Module):
     ):
         super().__init__()
         self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
-        assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
+        assert (
+            max_sequence_length % self.seq_parallel_size == 0
+        ), "sequence length is not divisible by the sequence parallel size"
         self.sub_seq_length = max_sequence_length // self.seq_parallel_size
         self.init_std = init_std
         self.num_layers = num_layers
@@ -43,28 +44,32 @@ class BertForPretrain(nn.Module):
             num_tokentypes = 0
 
         self.preprocessor = PreProcessor(self.sub_seq_length)
-        self.embedding = Embedding(hidden_size=hidden_size,
-                                   vocab_size=vocab_size,
-                                   max_sequence_length=max_sequence_length,
-                                   embedding_dropout_prob=dropout_prob,
-                                   num_tokentypes=num_tokentypes)
+        self.embedding = Embedding(
+            hidden_size=hidden_size,
+            vocab_size=vocab_size,
+            max_sequence_length=max_sequence_length,
+            embedding_dropout_prob=dropout_prob,
+            num_tokentypes=num_tokentypes,
+        )
         self.bert_layers = nn.ModuleList()
 
         for i in range(num_layers):
-            bert_layer = BertLayer(layer_number=i + 1,
-                                   hidden_size=hidden_size,
-                                   num_attention_heads=num_attention_heads,
-                                   attention_dropout=dropout_prob,
-                                   mlp_ratio=mlp_ratio,
-                                   hidden_dropout=dropout_prob,
-                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
-                                   is_naive_fp16=is_naive_fp16)
+            bert_layer = BertLayer(
+                layer_number=i + 1,
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                attention_dropout=dropout_prob,
+                mlp_ratio=mlp_ratio,
+                hidden_dropout=dropout_prob,
+                convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
+                is_naive_fp16=is_naive_fp16,
+            )
             self.bert_layers.append(bert_layer)
 
         self.layer_norm = LayerNorm(hidden_size)
-        self.head = BertDualHead(hidden_size,
-                                 self.embedding.word_embedding_weight.size(0),
-                                 add_binary_head=add_binary_head)
+        self.head = BertDualHead(
+            hidden_size, self.embedding.word_embedding_weight.size(0), add_binary_head=add_binary_head
+        )
         self.reset_parameters()
 
     def _init_normal(self, tensor):
@@ -122,27 +127,30 @@ class BertForPretrain(nn.Module):
 
 
 class PipelineBertForPretrain(nn.Module):
-
-    def __init__(self,
-                 vocab_size,
-                 hidden_size,
-                 max_sequence_length,
-                 num_attention_heads,
-                 num_layers,
-                 add_binary_head,
-                 is_naive_fp16,
-                 num_tokentypes=2,
-                 dropout_prob=0.1,
-                 mlp_ratio=4,
-                 init_std=0.02,
-                 convert_fp16_to_fp32_in_softmax=False,
-                 first_stage=True,
-                 last_stage=True,
-                 start_idx=None,
-                 end_idx=None):
+    def __init__(
+        self,
+        vocab_size,
+        hidden_size,
+        max_sequence_length,
+        num_attention_heads,
+        num_layers,
+        add_binary_head,
+        is_naive_fp16,
+        num_tokentypes=2,
+        dropout_prob=0.1,
+        mlp_ratio=4,
+        init_std=0.02,
+        convert_fp16_to_fp32_in_softmax=False,
+        first_stage=True,
+        last_stage=True,
+        start_idx=None,
+        end_idx=None,
+    ):
         super().__init__()
         self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
-        assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
+        assert (
+            max_sequence_length % self.seq_parallel_size == 0
+        ), "sequence length is not divisible by the sequence parallel size"
         self.sub_seq_length = max_sequence_length // self.seq_parallel_size
         self.init_std = init_std
         self.num_layers = num_layers
@@ -156,11 +164,13 @@ class PipelineBertForPretrain(nn.Module):
         self.preprocessor = PreProcessor(self.sub_seq_length)
 
         if self.first_stage:
-            self.embedding = Embedding(hidden_size=hidden_size,
-                                       vocab_size=vocab_size,
-                                       max_sequence_length=max_sequence_length,
-                                       embedding_dropout_prob=dropout_prob,
-                                       num_tokentypes=num_tokentypes)
+            self.embedding = Embedding(
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                max_sequence_length=max_sequence_length,
+                embedding_dropout_prob=dropout_prob,
+                num_tokentypes=num_tokentypes,
+            )
 
         # transformer layers
         self.bert_layers = nn.ModuleList()
@@ -170,14 +180,16 @@ class PipelineBertForPretrain(nn.Module):
             end_idx = num_layers
 
         for i in range(start_idx, end_idx):
-            bert_layer = BertLayer(layer_number=i + 1,
-                                   hidden_size=hidden_size,
-                                   num_attention_heads=num_attention_heads,
-                                   attention_dropout=dropout_prob,
-                                   mlp_ratio=mlp_ratio,
-                                   hidden_dropout=dropout_prob,
-                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
-                                   is_naive_fp16=is_naive_fp16)
+            bert_layer = BertLayer(
+                layer_number=i + 1,
+                hidden_size=hidden_size,
+                num_attention_heads=num_attention_heads,
+                attention_dropout=dropout_prob,
+                mlp_ratio=mlp_ratio,
+                hidden_dropout=dropout_prob,
+                convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
+                is_naive_fp16=is_naive_fp16,
+            )
             self.bert_layers.append(bert_layer)
 
         if self.last_stage:
@@ -256,7 +268,7 @@ def _filter_kwargs(func, kwargs):
     return {k: v for k, v in kwargs.items() if k in sig.parameters}
 
 
-def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
+def build_pipeline_bert(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     logger = get_dist_logger()
     pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
     pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
@@ -265,12 +277,12 @@ def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **k
     parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
     models = []
     for start, end in parts:
-        kwargs['num_layers'] = num_layers
-        kwargs['start_idx'] = start
-        kwargs['end_idx'] = end
-        kwargs['first_stage'] = start == 0
-        kwargs['last_stage'] = end == num_layers
-        logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
+        kwargs["num_layers"] = num_layers
+        kwargs["start_idx"] = start
+        kwargs["end_idx"] = end
+        kwargs["first_stage"] = start == 0
+        kwargs["last_stage"] = end == num_layers
+        logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
         chunk = PipelineBertForPretrain(**_filter_kwargs(PipelineBertForPretrain.__init__, kwargs)).to(device)
         if start == 0:
             wrapper.register_module(chunk.embedding.word_embeddings)
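For context on the build_pipeline_bert loop in the last hunk: each pipeline rank receives one or more (start, end) layer ranges from partition_uniform and builds a PipelineBertForPretrain chunk per range, flagging the chunk that holds layer 0 as the first stage (it owns the embedding) and the chunk ending at num_layers as the last stage (it owns the heads). A minimal sketch of that partition-and-flag logic, assuming a simple even split with one chunk per rank; split_layers below is a hypothetical stand-in, not ColossalAI's partition_uniform:

# Hypothetical sketch of uniform layer partitioning across pipeline stages.
# Not ColossalAI's partition_uniform; it only illustrates the first/last-stage flags.
def split_layers(num_layers, pipeline_size):
    per_stage = num_layers // pipeline_size  # assumes num_layers divides evenly
    return [(i * per_stage, (i + 1) * per_stage) for i in range(pipeline_size)]

num_layers, pipeline_size = 24, 4
for rank, (start, end) in enumerate(split_layers(num_layers, pipeline_size)):
    stage_kwargs = {
        "num_layers": num_layers,
        "start_idx": start,
        "end_idx": end,
        "first_stage": start == 0,        # embedding built only on this stage
        "last_stage": end == num_layers,  # heads built only on this stage
    }
    print(rank, stage_kwargs)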