mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -11,8 +11,10 @@ HIDDEN_SIZE = 768
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
|
||||
# if you do no want zero, just comment out this dictionary
|
||||
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**5))
|
||||
zero = dict(
|
||||
model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**5),
|
||||
)
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
@@ -27,5 +29,5 @@ model = dict(type=GPT2_small_pipeline_hybrid, checkpoint=True, num_chunks=1)
|
||||
# for the current model implementation, mode can only be 1D or None
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=2, mode='1d'),
|
||||
tensor=dict(size=2, mode="1d"),
|
||||
)
|
||||
|
@@ -11,8 +11,10 @@ HIDDEN_SIZE = 12288
|
||||
TENSOR_SHAPE = (BATCH_SIZE // NUM_MICRO_BATCHES, SEQ_LEN, HIDDEN_SIZE)
|
||||
|
||||
# if you do no want zero, just comment out this dictionary
|
||||
zero = dict(model_config=dict(tensor_placement_policy='cuda', shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**16))
|
||||
zero = dict(
|
||||
model_config=dict(tensor_placement_policy="cuda", shard_strategy=TensorShardStrategy()),
|
||||
optimizer_config=dict(initial_scale=2**16),
|
||||
)
|
||||
|
||||
optimizer = dict(
|
||||
type=HybridAdam,
|
||||
@@ -27,5 +29,5 @@ model = dict(type=GPT3_pipeline_hybrid, checkpoint=True, num_chunks=1)
|
||||
# for the current model implementation, mode can only be 1D or None
|
||||
parallel = dict(
|
||||
pipeline=1,
|
||||
tensor=dict(size=2, mode='1d'), # for the current model implementation, mode can only be 1D or None
|
||||
tensor=dict(size=2, mode="1d"), # for the current model implementation, mode can only be 1D or None
|
||||
)
|
||||
|
@@ -11,12 +11,11 @@ from colossalai.legacy.registry import DATASETS
|
||||
|
||||
@DATASETS.register_module
|
||||
class WebtextDataset(Dataset):
|
||||
|
||||
def __init__(self, path: Optional[str] = None, seq_len=1024) -> None:
|
||||
super().__init__()
|
||||
if path is not None:
|
||||
root = os.path.dirname(path)
|
||||
encoded_data_cache_path = os.path.join(root, f'gpt_webtext_{seq_len}.pt')
|
||||
encoded_data_cache_path = os.path.join(root, f"gpt_webtext_{seq_len}.pt")
|
||||
if os.path.isfile(encoded_data_cache_path):
|
||||
seq_len_, data, attention_mask = torch.load(encoded_data_cache_path)
|
||||
if seq_len_ == seq_len:
|
||||
@@ -26,12 +25,12 @@ class WebtextDataset(Dataset):
|
||||
raw_data = []
|
||||
with open(path) as f:
|
||||
for line in f.readlines():
|
||||
raw_data.append(json.loads(line)['text'])
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
raw_data.append(json.loads(line)["text"])
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.unk_token
|
||||
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors='pt')
|
||||
self.data = encoded_data['input_ids']
|
||||
self.attention_mask = encoded_data['attention_mask']
|
||||
encoded_data = tokenizer(raw_data, padding=True, truncation=True, max_length=seq_len, return_tensors="pt")
|
||||
self.data = encoded_data["input_ids"]
|
||||
self.attention_mask = encoded_data["attention_mask"]
|
||||
else:
|
||||
self.data = torch.randint(0, 50257, (10240, seq_len))
|
||||
self.attention_mask = torch.ones_like(self.data)
|
||||
@@ -40,4 +39,4 @@ class WebtextDataset(Dataset):
|
||||
return len(self.data)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return {'input_ids': self.data[index], 'attention_mask': self.attention_mask[index]}, self.data[index]
|
||||
return {"input_ids": self.data[index], "attention_mask": self.attention_mask[index]}, self.data[index]
|
||||
|
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn.init as init
|
||||
from torch import Tensor
|
||||
from torch import distributed as dist
|
||||
from torch import nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -12,7 +11,7 @@ from colossalai.legacy.nn.layer.base_layer import ParallelLayer
|
||||
from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
|
||||
from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
|
||||
from colossalai.legacy.nn.layer.utils import divide
|
||||
from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
|
||||
from colossalai.legacy.registry import LAYERS, LOSSES
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
@@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size,
|
||||
vocab_size,
|
||||
max_sequence_length,
|
||||
embedding_dropout_prob,
|
||||
num_tokentypes=0,
|
||||
dtype=torch.float):
|
||||
def __init__(
|
||||
self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float
|
||||
):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
@@ -44,11 +39,11 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
|
||||
self._word_embeddings_key = 'word_embeddings'
|
||||
self._word_embeddings_key = "word_embeddings"
|
||||
|
||||
# Position embedding (serial).
|
||||
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
|
||||
self._position_embeddings_key = 'position_embeddings'
|
||||
self._position_embeddings_key = "position_embeddings"
|
||||
# Initialize the position embeddings.
|
||||
# self.init_method(self.position_embeddings.weight)
|
||||
|
||||
@@ -56,7 +51,7 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
# Add this as an optional field that can be added through
|
||||
# method call so we can load a pretrain model without
|
||||
# token types and add them as needed.
|
||||
self._tokentype_embeddings_key = 'tokentype_embeddings'
|
||||
self._tokentype_embeddings_key = "tokentype_embeddings"
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
|
||||
# Initialize the token-type embeddings.
|
||||
@@ -83,9 +78,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
This allows us to load the model normally and then add this embedding.
|
||||
"""
|
||||
if self.tokentype_embeddings is not None:
|
||||
raise Exception('tokentype embeddings is already initialized')
|
||||
raise Exception("tokentype embeddings is already initialized")
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
|
||||
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
|
||||
self.num_tokentypes = num_tokentypes
|
||||
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
@@ -112,19 +107,16 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
|
||||
"""For easy load."""
|
||||
|
||||
state_dict_ = {}
|
||||
state_dict_[self._word_embeddings_key] \
|
||||
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] \
|
||||
= self.position_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_[self._tokentype_embeddings_key] \
|
||||
= self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars
|
||||
)
|
||||
|
||||
return state_dict_
|
||||
|
||||
@@ -138,9 +130,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'word_embeddings' in key:
|
||||
state_dict_[key.split('word_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "word_embeddings" in key:
|
||||
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
|
||||
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Position embedding.
|
||||
@@ -150,9 +141,8 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'position_embeddings' in key:
|
||||
state_dict_[key.split('position_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "position_embeddings" in key:
|
||||
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
|
||||
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Tokentype embedding.
|
||||
@@ -163,15 +153,14 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
# for backward compatibility.
|
||||
for key in state_dict.keys():
|
||||
if 'tokentype_embeddings' in key:
|
||||
state_dict_[key.split('tokentype_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "tokentype_embeddings" in key:
|
||||
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
|
||||
if len(state_dict_.keys()) > 0:
|
||||
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
else:
|
||||
print('***WARNING*** expected tokentype embeddings in the '
|
||||
'checkpoint but could not find it',
|
||||
flush=True)
|
||||
print(
|
||||
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
|
||||
)
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(torch.nn.Module):
|
||||
@@ -193,37 +182,41 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
||||
# Set the details for compatibility.
|
||||
self.padding_idx = None
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.norm_type = 2.0
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = gpc.tensor_parallel_size
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = \
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
|
||||
self.tensor_model_parallel_size)
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size
|
||||
)
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
|
||||
|
||||
# Allocate weights and initialize.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
|
||||
init.uniform_(self.weight, -1, 1)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = (input_ < self.vocab_start_index) | \
|
||||
(input_ >= self.vocab_end_index)
|
||||
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
output_parallel = F.embedding(
|
||||
masked_input,
|
||||
self.weight,
|
||||
self.padding_idx,
|
||||
self.max_norm,
|
||||
self.norm_type,
|
||||
self.scale_grad_by_freq,
|
||||
self.sparse,
|
||||
)
|
||||
# Mask the output embedding.
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
@@ -234,7 +227,6 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
||||
|
||||
@LOSSES.register_module
|
||||
class vocab_parallel_cross_entropy(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -242,20 +234,19 @@ class vocab_parallel_cross_entropy(nn.Module):
|
||||
"""Helper function for the cross entropy."""
|
||||
vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
|
||||
target = target[..., 1:].contiguous()
|
||||
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
|
||||
target.view(-1))
|
||||
return _VocabParallelCrossEntropy.apply(
|
||||
vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1)
|
||||
)
|
||||
|
||||
|
||||
class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_parallel_logits, target):
|
||||
|
||||
# Maximum value along vocab dimension across all GPUs.
|
||||
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
|
||||
torch.distributed.all_reduce(logits_max,
|
||||
op=torch.distributed.ReduceOp.MAX,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.all_reduce(
|
||||
logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D)
|
||||
)
|
||||
# Subtract the maximum value.
|
||||
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
|
||||
|
||||
@@ -282,17 +273,17 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
predicted_logits = predicted_logits_1d.view_as(target)
|
||||
predicted_logits[target_mask] = 0.0
|
||||
# All reduce is needed to get the chunks from other GPUs.
|
||||
torch.distributed.all_reduce(predicted_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.all_reduce(
|
||||
predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
|
||||
)
|
||||
|
||||
# Sum of exponential of logits along vocab dimension across all GPUs.
|
||||
exp_logits = vocab_parallel_logits
|
||||
torch.exp(vocab_parallel_logits, out=exp_logits)
|
||||
sum_exp_logits = exp_logits.sum(dim=-1)
|
||||
torch.distributed.all_reduce(sum_exp_logits,
|
||||
op=torch.distributed.ReduceOp.SUM,
|
||||
group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.all_reduce(
|
||||
sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
|
||||
)
|
||||
|
||||
# Loss = log(sum(exp(logits))) - predicted-logit.
|
||||
loss = torch.log(sum_exp_logits) - predicted_logits
|
||||
@@ -304,7 +295,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
|
||||
# Retrieve tensors from the forward path.
|
||||
softmax, target_mask, masked_target_1d = ctx.saved_tensors
|
||||
|
||||
@@ -316,7 +306,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
|
||||
# Add the gradient from matching classes.
|
||||
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
|
||||
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
|
||||
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()
|
||||
|
||||
# Finally elementwise multiplication with the output gradients.
|
||||
grad_input.mul_(grad_output.unsqueeze(dim=-1))
|
||||
@@ -326,8 +316,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
|
||||
|
||||
class VocabUtility:
|
||||
"""Split the vocabulary into `world_size` chunks amd return the
|
||||
first and last index of the vocabulary belonging to the `rank`
|
||||
partition: Note that indices in [fist, last)"""
|
||||
first and last index of the vocabulary belonging to the `rank`
|
||||
partition: Note that indices in [fist, last)"""
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
|
||||
@@ -393,11 +383,11 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
|
||||
# Word embeddings (parallel).
|
||||
self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
|
||||
self._word_embeddings_key = 'word_embeddings'
|
||||
self._word_embeddings_key = "word_embeddings"
|
||||
|
||||
# Position embedding (serial).
|
||||
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
|
||||
self._position_embeddings_key = 'position_embeddings'
|
||||
self._position_embeddings_key = "position_embeddings"
|
||||
# Initialize the position embeddings.
|
||||
# self.init_method(self.position_embeddings.weight)
|
||||
|
||||
@@ -405,7 +395,7 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
# Add this as an optional field that can be added through
|
||||
# method call so we can load a pretrain model without
|
||||
# token types and add them as needed.
|
||||
self._tokentype_embeddings_key = 'tokentype_embeddings'
|
||||
self._tokentype_embeddings_key = "tokentype_embeddings"
|
||||
if self.num_tokentypes > 0:
|
||||
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
@@ -432,9 +422,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
This allows us to load the model normally and then add this embedding.
|
||||
"""
|
||||
if self.tokentype_embeddings is not None:
|
||||
raise Exception('tokentype embeddings is already initialized')
|
||||
raise Exception("tokentype embeddings is already initialized")
|
||||
if torch.distributed.get_rank() == 0:
|
||||
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
|
||||
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
|
||||
self.num_tokentypes = num_tokentypes
|
||||
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
|
||||
# Initialize the token-type embeddings.
|
||||
@@ -460,19 +450,16 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
return embeddings
|
||||
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
|
||||
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
|
||||
"""For easy load."""
|
||||
|
||||
state_dict_ = {}
|
||||
state_dict_[self._word_embeddings_key] \
|
||||
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] \
|
||||
= self.position_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
|
||||
if self.num_tokentypes > 0:
|
||||
state_dict_[self._tokentype_embeddings_key] \
|
||||
= self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars)
|
||||
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
|
||||
destination, prefix, keep_vars
|
||||
)
|
||||
|
||||
return state_dict_
|
||||
|
||||
@@ -486,9 +473,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'word_embeddings' in key:
|
||||
state_dict_[key.split('word_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "word_embeddings" in key:
|
||||
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
|
||||
self.word_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Position embedding.
|
||||
@@ -498,9 +484,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
# for backward compatibility.
|
||||
state_dict_ = {}
|
||||
for key in state_dict.keys():
|
||||
if 'position_embeddings' in key:
|
||||
state_dict_[key.split('position_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "position_embeddings" in key:
|
||||
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
|
||||
self.position_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
|
||||
# Tokentype embedding.
|
||||
@@ -511,15 +496,14 @@ class HiddenParallelEmbedding(torch.nn.Module):
|
||||
else:
|
||||
# for backward compatibility.
|
||||
for key in state_dict.keys():
|
||||
if 'tokentype_embeddings' in key:
|
||||
state_dict_[key.split('tokentype_embeddings.')[1]] \
|
||||
= state_dict[key]
|
||||
if "tokentype_embeddings" in key:
|
||||
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
|
||||
if len(state_dict_.keys()) > 0:
|
||||
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
|
||||
else:
|
||||
print('***WARNING*** expected tokentype embeddings in the '
|
||||
'checkpoint but could not find it',
|
||||
flush=True)
|
||||
print(
|
||||
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
|
||||
)
|
||||
|
||||
|
||||
class HiddenParallelEmbedding1D(torch.nn.Module):
|
||||
@@ -542,21 +526,21 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
|
||||
# Set the details for compatibility.
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.norm_type = 2.0
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
|
||||
# Allocate weights and initialize.
|
||||
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
|
||||
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
|
||||
self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
|
||||
init.uniform_(self.weight, -1, 1)
|
||||
|
||||
def forward(self, input_):
|
||||
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
|
||||
self.scale_grad_by_freq, self.sparse)
|
||||
output_parallel = F.embedding(
|
||||
input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
|
||||
)
|
||||
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
@@ -584,11 +568,9 @@ class HiddenParallelGPTLMHead1D(ParallelLayer):
|
||||
# self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
|
||||
# (hidden_size/q, vocab_size)
|
||||
self.synced_embed = False
|
||||
self.head = Linear1D_Row(in_features=embed_dim,
|
||||
out_features=vocab_size,
|
||||
bias=False,
|
||||
dtype=dtype,
|
||||
parallel_input=False)
|
||||
self.head = Linear1D_Row(
|
||||
in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
if self.synced_embed:
|
||||
|
@@ -18,18 +18,21 @@ from colossalai.legacy.utils.activation_checkpoint import checkpoint
|
||||
from colossalai.utils import checkpoint
|
||||
|
||||
__all__ = [
|
||||
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
|
||||
"GPTMLP1D",
|
||||
"GPTSelfAttention1D",
|
||||
"GPTTransformerLayer1D",
|
||||
"FusedGPTSelfAttention1D",
|
||||
"FusedGPTTransformerLayer1D",
|
||||
]
|
||||
|
||||
|
||||
class GPTMLP1D(ParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
mlp_ratio: int,
|
||||
act_func: str = 'gelu',
|
||||
dropout_prob: float = 0.,
|
||||
act_func: str = "gelu",
|
||||
dropout_prob: float = 0.0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
@@ -82,7 +85,6 @@ class GPTMLP1D(ParallelLayer):
|
||||
|
||||
|
||||
class GenericGPTSelfAttention1D(ParallelLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
@@ -118,8 +120,10 @@ class GenericGPTSelfAttention1D(ParallelLayer):
|
||||
|
||||
def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
|
||||
query_key_value = self.query_key_value(hidden_states)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + \
|
||||
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
|
||||
new_qkv_shape = query_key_value.shape[:-1] + (
|
||||
self.num_attention_heads_per_partition,
|
||||
3 * self.attention_head_size,
|
||||
)
|
||||
query_key_value = query_key_value.view(new_qkv_shape)
|
||||
query_key_value = query_key_value.permute((0, 2, 1, 3))
|
||||
query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
|
||||
@@ -152,28 +156,32 @@ class GenericGPTSelfAttention1D(ParallelLayer):
|
||||
|
||||
|
||||
class GPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024):
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings)
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
max_positions = max_position_embeddings
|
||||
self.register_buffer(
|
||||
"bias",
|
||||
torch.tril(torch.ones((max_positions, max_positions),
|
||||
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
|
||||
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
|
||||
1, 1, max_positions, max_positions
|
||||
),
|
||||
)
|
||||
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
||||
|
||||
@@ -181,7 +189,7 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
||||
# causal mask
|
||||
query_length, key_length = query_layer.size(-2), key_layer.size(-2)
|
||||
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
|
||||
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
|
||||
attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
@@ -191,50 +199,56 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
|
||||
|
||||
class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024):
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings)
|
||||
self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
|
||||
input_in_bf16=False,
|
||||
attn_mask_type=AttnMaskType.causal,
|
||||
scaled_masked_softmax_fusion=True,
|
||||
mask_func=None,
|
||||
softmax_in_fp32=True,
|
||||
scale=math.sqrt(self.attention_head_size))
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
attention_dropout_prob: float,
|
||||
hidden_dropout_prob: float,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings=1024,
|
||||
):
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
attention_dropout_prob,
|
||||
hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
)
|
||||
self.softmax = kernel.FusedScaleMaskSoftmax(
|
||||
input_in_fp16=True,
|
||||
input_in_bf16=False,
|
||||
attn_mask_type=AttnMaskType.causal,
|
||||
scaled_masked_softmax_fusion=True,
|
||||
mask_func=None,
|
||||
softmax_in_fp32=True,
|
||||
scale=math.sqrt(self.attention_head_size),
|
||||
)
|
||||
|
||||
def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
|
||||
return self.softmax(attention_scores, attention_mask)
|
||||
|
||||
|
||||
class GenericGPTTransformerLayer1D(ParallelLayer):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.,
|
||||
hidden_dropout_prob: float = 0.,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
attention=None,
|
||||
layer_norm=None):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: float = 4.0,
|
||||
attention_dropout_prob: float = 0.0,
|
||||
hidden_dropout_prob: float = 0.0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
attention=None,
|
||||
layer_norm=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
self.dtype = dtype
|
||||
@@ -288,62 +302,68 @@ class GenericGPTTransformerLayer1D(ParallelLayer):
|
||||
|
||||
|
||||
class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False,
|
||||
):
|
||||
attention = GPTSelfAttention1D
|
||||
layer_norm = nn.LayerNorm
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm)
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm,
|
||||
)
|
||||
|
||||
|
||||
class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: float = 4,
|
||||
attention_dropout_prob: float = 0,
|
||||
hidden_dropout_prob: float = 0,
|
||||
dtype=None,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 0.00001,
|
||||
apply_post_layer_norm: bool = False,
|
||||
):
|
||||
attention = FusedGPTSelfAttention1D
|
||||
layer_norm = kernel.LayerNorm
|
||||
super().__init__(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm)
|
||||
super().__init__(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attention_dropout_prob,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
attention=attention,
|
||||
layer_norm=layer_norm,
|
||||
)
|
||||
|
@@ -17,17 +17,16 @@ from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabPara
|
||||
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D
|
||||
|
||||
__all__ = [
|
||||
'GPT2_small_pipeline_1D',
|
||||
'GPT2_exlarge_pipeline_1D',
|
||||
'GPT3_pipeline_1D',
|
||||
'GPT2_exlarge_pipeline_hybrid',
|
||||
'GPT2_small_pipeline_hybrid',
|
||||
'GPT3_pipeline_hybrid',
|
||||
"GPT2_small_pipeline_1D",
|
||||
"GPT2_exlarge_pipeline_1D",
|
||||
"GPT3_pipeline_1D",
|
||||
"GPT2_exlarge_pipeline_hybrid",
|
||||
"GPT2_small_pipeline_hybrid",
|
||||
"GPT3_pipeline_hybrid",
|
||||
]
|
||||
|
||||
|
||||
class GenericPipelineGPT(nn.Module):
|
||||
|
||||
def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
|
||||
super().__init__()
|
||||
self.embedding = embedding
|
||||
@@ -44,7 +43,7 @@ class GenericPipelineGPT(nn.Module):
|
||||
batch_size = hidden_states.shape[0]
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
|
||||
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
for block in self.blocks:
|
||||
hidden_states, attention_mask = block(hidden_states, attention_mask)
|
||||
@@ -54,25 +53,26 @@ class GenericPipelineGPT(nn.Module):
|
||||
|
||||
|
||||
class PipelineGPT1D(GenericPipelineGPT):
|
||||
|
||||
def __init__(self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_rate: float = 0.,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.0,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_rate: float = 0.0,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False,
|
||||
):
|
||||
embedding = None
|
||||
norm = None
|
||||
head = None
|
||||
@@ -83,19 +83,24 @@ class PipelineGPT1D(GenericPipelineGPT):
|
||||
head_cls = HiddenParallelGPTLMHead1D
|
||||
if first:
|
||||
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
|
||||
blocks = nn.ModuleList([
|
||||
GPTTransformerLayer1D(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attn_drop_rate,
|
||||
hidden_dropout_prob=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
|
||||
])
|
||||
blocks = nn.ModuleList(
|
||||
[
|
||||
GPTTransformerLayer1D(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attn_drop_rate,
|
||||
hidden_dropout_prob=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if last:
|
||||
norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
|
||||
@@ -103,25 +108,26 @@ class PipelineGPT1D(GenericPipelineGPT):
|
||||
|
||||
|
||||
class FusedPipelineGPT1D(GenericPipelineGPT):
|
||||
|
||||
def __init__(self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_rate: float = 0.,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.0,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: int = 4.0,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_rate: float = 0.0,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False,
|
||||
):
|
||||
embedding = None
|
||||
norm = None
|
||||
head = None
|
||||
@@ -132,19 +138,24 @@ class FusedPipelineGPT1D(GenericPipelineGPT):
|
||||
head_cls = HiddenParallelGPTLMHead1D
|
||||
if first:
|
||||
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
|
||||
blocks = nn.ModuleList([
|
||||
FusedGPTTransformerLayer1D(hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attn_drop_rate,
|
||||
hidden_dropout_prob=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
|
||||
])
|
||||
blocks = nn.ModuleList(
|
||||
[
|
||||
FusedGPTTransformerLayer1D(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
act_func=act_func,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout_prob=attn_drop_rate,
|
||||
hidden_dropout_prob=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
layer_norm_epsilon=layer_norm_epsilon,
|
||||
apply_post_layer_norm=apply_post_layer_norm,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if last:
|
||||
norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
|
||||
@@ -153,7 +164,7 @@ class FusedPipelineGPT1D(GenericPipelineGPT):
|
||||
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
|
||||
if self.embedding is not None:
|
||||
hidden_states = self.embedding(input_ids=input_ids)
|
||||
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
|
||||
attention_mask = attention_mask.to(dtype=hidden_states.dtype) # fp16 compatibility
|
||||
for block in self.blocks:
|
||||
hidden_states, attention_mask = block(hidden_states, attention_mask)
|
||||
if self.norm is not None:
|
||||
@@ -162,44 +173,48 @@ class FusedPipelineGPT1D(GenericPipelineGPT):
|
||||
|
||||
|
||||
class PipelineGPTHybrid(GenericPipelineGPT):
|
||||
|
||||
def __init__(self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.,
|
||||
act_func: str = 'gelu',
|
||||
mlp_ratio: int = 4,
|
||||
attn_drop_rate: float = 0.,
|
||||
drop_rate: float = 0.,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False):
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int = 12,
|
||||
hidden_size: int = 768,
|
||||
num_attention_heads: int = 12,
|
||||
vocab_size: int = 50304,
|
||||
embed_drop_rate: float = 0.0,
|
||||
act_func: str = "gelu",
|
||||
mlp_ratio: int = 4,
|
||||
attn_drop_rate: float = 0.0,
|
||||
drop_rate: float = 0.0,
|
||||
dtype: torch.dtype = torch.float,
|
||||
checkpoint: bool = False,
|
||||
max_position_embeddings: int = 1024,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
apply_post_layer_norm: bool = False,
|
||||
first: bool = False,
|
||||
last: bool = False,
|
||||
embed_split_hidden=False,
|
||||
):
|
||||
embedding = None
|
||||
norm = None
|
||||
head = None
|
||||
if first:
|
||||
embedding = col_gpt.GPTEmbedding(hidden_size,
|
||||
vocab_size,
|
||||
max_position_embeddings,
|
||||
dropout=embed_drop_rate,
|
||||
dtype=dtype)
|
||||
blocks = nn.ModuleList([
|
||||
col_gpt.GPTBlock(hidden_size,
|
||||
num_attention_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout=attn_drop_rate,
|
||||
dropout=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
activation=nn.functional.gelu) for _ in range(num_layers)
|
||||
])
|
||||
embedding = col_gpt.GPTEmbedding(
|
||||
hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype
|
||||
)
|
||||
blocks = nn.ModuleList(
|
||||
[
|
||||
col_gpt.GPTBlock(
|
||||
hidden_size,
|
||||
num_attention_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
attention_dropout=attn_drop_rate,
|
||||
dropout=drop_rate,
|
||||
dtype=dtype,
|
||||
checkpoint=checkpoint,
|
||||
activation=nn.functional.gelu,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
if last:
|
||||
norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
|
||||
# head = col_gpt.GPTLMHead(vocab_size=vocab_size,
|
||||
@@ -215,7 +230,7 @@ def _filter_kwargs(func, kwargs):
|
||||
return {k: v for k, v in kwargs.items() if k in sig.parameters}
|
||||
|
||||
|
||||
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
|
||||
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
|
||||
logger = get_dist_logger()
|
||||
|
||||
if gpc.is_initialized(ParallelMode.PIPELINE):
|
||||
@@ -233,10 +248,10 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to
|
||||
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
|
||||
models = []
|
||||
for start, end in parts:
|
||||
kwargs['num_layers'] = end - start
|
||||
kwargs['first'] = start == 0
|
||||
kwargs['last'] = end == num_layers
|
||||
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
|
||||
kwargs["num_layers"] = end - start
|
||||
kwargs["first"] = start == 0
|
||||
kwargs["last"] = end == num_layers
|
||||
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
|
||||
chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)
|
||||
|
||||
if wrapper is not None:
|
||||
@@ -253,70 +268,82 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to
|
||||
numel = 0
|
||||
for _, param in model.named_parameters(recurse=True):
|
||||
numel += param.numel()
|
||||
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
|
||||
logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB")
|
||||
return model
|
||||
|
||||
|
||||
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
|
||||
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs):
|
||||
model = FusedPipelineGPT1D if fused else PipelineGPT1D
|
||||
return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)
|
||||
|
||||
|
||||
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
|
||||
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
|
||||
return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)
|
||||
|
||||
|
||||
def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
|
||||
cfg = dict(hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)
|
||||
|
||||
|
||||
def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
|
||||
cfg = dict(hidden_size=1600,
|
||||
num_attention_heads=32,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=1600,
|
||||
num_attention_heads=32,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)
|
||||
|
||||
|
||||
def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
|
||||
cfg = dict(hidden_size=12288,
|
||||
num_attention_heads=96,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=2048,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=12288,
|
||||
num_attention_heads=96,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=2048,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)
|
||||
|
||||
|
||||
def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
|
||||
cfg = dict(hidden_size=1600,
|
||||
num_attention_heads=32,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=1600,
|
||||
num_attention_heads=32,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)
|
||||
|
||||
|
||||
def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
|
||||
cfg = dict(hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=768,
|
||||
num_attention_heads=12,
|
||||
checkpoint=checkpoint,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)
|
||||
|
||||
|
||||
def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
|
||||
cfg = dict(hidden_size=12288,
|
||||
num_attention_heads=96,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=2048,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden)
|
||||
cfg = dict(
|
||||
hidden_size=12288,
|
||||
num_attention_heads=96,
|
||||
checkpoint=checkpoint,
|
||||
max_position_embeddings=2048,
|
||||
dtype=dtype,
|
||||
embed_split_hidden=embed_split_hidden,
|
||||
)
|
||||
return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)
|
||||
|
@@ -14,7 +14,7 @@ from colossalai.legacy.trainer import Trainer, hooks
|
||||
from colossalai.legacy.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
from colossalai.nn import LinearWarmupLR
|
||||
from colossalai.utils import colo_set_process_memory_fraction, is_using_pp
|
||||
from colossalai.utils import is_using_pp
|
||||
from colossalai.utils.timer import MultiTimer
|
||||
|
||||
|
||||
@@ -30,8 +30,8 @@ VOCAB_SIZE = 50257
|
||||
|
||||
def main():
|
||||
parser = colossalai.get_default_parser()
|
||||
parser.add_argument('--from_torch', default=False, action='store_true')
|
||||
parser.add_argument('--use_dummy_dataset', default=False, action='store_true')
|
||||
parser.add_argument("--from_torch", default=False, action="store_true")
|
||||
parser.add_argument("--use_dummy_dataset", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
disable_existing_loggers()
|
||||
if args.from_torch:
|
||||
@@ -40,28 +40,27 @@ def main():
|
||||
colossalai.launch_from_slurm(config=args.config, host=args.host, port=29500, seed=42)
|
||||
logger = get_dist_logger()
|
||||
|
||||
data_path = None if args.use_dummy_dataset else os.environ['DATA']
|
||||
logger.info(f'Build data loader from path {data_path}', ranks=[0])
|
||||
data_path = None if args.use_dummy_dataset else os.environ["DATA"]
|
||||
logger.info(f"Build data loader from path {data_path}", ranks=[0])
|
||||
|
||||
train_ds = WebtextDataset(path=data_path, seq_len=gpc.config.SEQ_LEN)
|
||||
train_dataloader = utils.get_dataloader(train_ds,
|
||||
seed=42,
|
||||
batch_size=gpc.config.BATCH_SIZE,
|
||||
pin_memory=True,
|
||||
shuffle=True,
|
||||
drop_last=True)
|
||||
train_dataloader = utils.get_dataloader(
|
||||
train_ds, seed=42, batch_size=gpc.config.BATCH_SIZE, pin_memory=True, shuffle=True, drop_last=True
|
||||
)
|
||||
|
||||
logger.info('Build model', ranks=[0])
|
||||
logger.info("Build model", ranks=[0])
|
||||
use_pipeline = is_using_pp()
|
||||
use_interleaved = hasattr(gpc.config.model, 'num_chunks')
|
||||
use_zero3 = hasattr(gpc.config, 'zero')
|
||||
use_interleaved = hasattr(gpc.config.model, "num_chunks")
|
||||
use_zero3 = hasattr(gpc.config, "zero")
|
||||
ctx = contextlib.nullcontext()
|
||||
if use_zero3:
|
||||
ctx = ZeroInitContext(target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True)
|
||||
ctx = ZeroInitContext(
|
||||
target_device=torch.cuda.current_device(),
|
||||
shard_strategy=gpc.config.zero.model_config.shard_strategy,
|
||||
shard_param=True,
|
||||
)
|
||||
with ctx:
|
||||
model = gpc.config.model.pop('type')(**gpc.config.model)
|
||||
model = gpc.config.model.pop("type")(**gpc.config.model)
|
||||
if use_pipeline and use_interleaved and not isinstance(model, nn.ModuleList):
|
||||
model = nn.ModuleList([model])
|
||||
|
||||
@@ -70,25 +69,31 @@ def main():
|
||||
else:
|
||||
numel = calc_local_model_size(model)
|
||||
|
||||
tflop = numel * gpc.config.BATCH_SIZE * gpc.config.SEQ_LEN \
|
||||
* gpc.get_world_size(ParallelMode.MODEL) * gpc.get_world_size(ParallelMode.DATA) * 8 / (1024 ** 4)
|
||||
tflop = (
|
||||
numel
|
||||
* gpc.config.BATCH_SIZE
|
||||
* gpc.config.SEQ_LEN
|
||||
* gpc.get_world_size(ParallelMode.MODEL)
|
||||
* gpc.get_world_size(ParallelMode.DATA)
|
||||
* 8
|
||||
/ (1024**4)
|
||||
)
|
||||
|
||||
criterion = getattr(gpc.config, 'loss_fn', None)
|
||||
criterion = getattr(gpc.config, "loss_fn", None)
|
||||
if criterion is not None:
|
||||
criterion = criterion.type()
|
||||
else:
|
||||
criterion = GPTLMLoss()
|
||||
logger.info('Build optimizer', ranks=[0])
|
||||
optimizer = gpc.config.optimizer.pop('type')(model.parameters(), **gpc.config.optimizer)
|
||||
logger.info("Build optimizer", ranks=[0])
|
||||
optimizer = gpc.config.optimizer.pop("type")(model.parameters(), **gpc.config.optimizer)
|
||||
lr_scheduler = LinearWarmupLR(optimizer, total_steps=gpc.config.NUM_EPOCHS, warmup_steps=5)
|
||||
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(model,
|
||||
optimizer,
|
||||
criterion,
|
||||
train_dataloader=train_dataloader,
|
||||
lr_scheduler=lr_scheduler)
|
||||
global_batch_size = gpc.config.BATCH_SIZE * \
|
||||
gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
|
||||
logger.info(f'Init done, global batch size = {global_batch_size}', ranks=[0])
|
||||
engine, train_dataloader, _, lr_scheduler = colossalai.initialize(
|
||||
model, optimizer, criterion, train_dataloader=train_dataloader, lr_scheduler=lr_scheduler
|
||||
)
|
||||
global_batch_size = (
|
||||
gpc.config.BATCH_SIZE * gpc.get_world_size(ParallelMode.DATA) * getattr(gpc.config, "gradient_accumulation", 1)
|
||||
)
|
||||
logger.info(f"Init done, global batch size = {global_batch_size}", ranks=[0])
|
||||
timier = MultiTimer()
|
||||
trainer = Trainer(engine=engine, logger=logger, timer=timier)
|
||||
hook_list = [
|
||||
@@ -98,16 +103,18 @@ def main():
|
||||
hooks.ThroughputHook(ignored_steps=10, tflop_per_step=tflop),
|
||||
hooks.LogMetricByStepHook(),
|
||||
hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogTimingByEpochHook(timer, logger),
|
||||
# hooks.LogMemoryByEpochHook(logger),
|
||||
# hooks.LogTimingByEpochHook(timer, logger),
|
||||
]
|
||||
trainer.fit(train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
return_output_label=False)
|
||||
trainer.fit(
|
||||
train_dataloader=train_dataloader,
|
||||
epochs=gpc.config.NUM_EPOCHS,
|
||||
test_interval=1,
|
||||
hooks=hook_list,
|
||||
display_progress=True,
|
||||
return_output_label=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user