Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 01:55:12 +00:00)
[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
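
The diff below is the result of that pre-commit run. The exact hook set is not shown on this page, so the configuration here is an assumption; the visible edits (single quotes rewritten to double quotes, hand-wrapped call sites reflowed against a long line limit) match what a Black-style formatter produces. A minimal sketch of such a pass, applied programmatically:

# Illustrative only: assumes Black is the formatter (pip install black); the commit
# itself does not list the repository's actual hooks or configured line length.
import black

before = "factory_kwargs = {'device': get_current_device(), 'dtype': dtype}\n"

# Black parses the source, normalizes string quotes, and reflows long lines,
# the same pattern of changes seen throughout this diff.
after = black.format_str(before, mode=black.Mode(line_length=120))
print(after)  # factory_kwargs = {"device": get_current_device(), "dtype": dtype}

Running the hooks through pre-commit (pre-commit run --all-files) applies the same rewrite to every tracked file, which is why the diff touches formatting only, not behavior.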
@@ -1,7 +1,6 @@
import torch
import torch.nn.init as init
from torch import Tensor
from torch import distributed as dist
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
@@ -12,7 +11,7 @@ from colossalai.legacy.nn.layer.base_layer import ParallelLayer
from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
from colossalai.legacy.nn.layer.utils import divide
from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
from colossalai.legacy.registry import LAYERS, LOSSES
from colossalai.utils import get_current_device


@@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module):
will ignore this embedding
"""

def __init__(self,
hidden_size,
vocab_size,
max_sequence_length,
embedding_dropout_prob,
num_tokentypes=0,
dtype=torch.float):
def __init__(
self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float
):
super(VocabParallelEmbedding, self).__init__()

self.hidden_size = hidden_size
@@ -44,11 +39,11 @@ class VocabParallelEmbedding(torch.nn.Module):

# Word embeddings (parallel).
self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
self._word_embeddings_key = 'word_embeddings'
self._word_embeddings_key = "word_embeddings"

# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
self._position_embeddings_key = 'position_embeddings'
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
# self.init_method(self.position_embeddings.weight)

@@ -56,7 +51,7 @@ class VocabParallelEmbedding(torch.nn.Module):
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
# Initialize the token-type embeddings.
@@ -83,9 +78,9 @@ class VocabParallelEmbedding(torch.nn.Module):
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
@@ -112,19 +107,16 @@ class VocabParallelEmbedding(torch.nn.Module):
embeddings = self.embedding_dropout(embeddings)
return embeddings

def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""

state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)

return state_dict_

@@ -138,9 +130,8 @@ class VocabParallelEmbedding(torch.nn.Module):
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
if "word_embeddings" in key:
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
@@ -150,9 +141,8 @@ class VocabParallelEmbedding(torch.nn.Module):
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
if "position_embeddings" in key:
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)

# Tokentype embedding.
@@ -163,15 +153,14 @@ class VocabParallelEmbedding(torch.nn.Module):
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it',
flush=True)
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)


class VocabParallelEmbedding1D(torch.nn.Module):
@@ -193,37 +182,41 @@ class VocabParallelEmbedding1D(torch.nn.Module):
# Set the details for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = gpc.tensor_parallel_size
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size
)
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

# Allocate weights and initialize.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
init.uniform_(self.weight, -1, 1)

def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
@@ -234,7 +227,6 @@ class VocabParallelEmbedding1D(torch.nn.Module):

@LOSSES.register_module
class vocab_parallel_cross_entropy(nn.Module):

def __init__(self):
super().__init__()

@@ -242,20 +234,19 @@ class vocab_parallel_cross_entropy(nn.Module):
"""Helper function for the cross entropy."""
vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
target = target[..., 1:].contiguous()
return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
target.view(-1))
return _VocabParallelCrossEntropy.apply(
vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1)
)


class _VocabParallelCrossEntropy(torch.autograd.Function):

@staticmethod
def forward(ctx, vocab_parallel_logits, target):

# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.all_reduce(
logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D)
)
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

@@ -282,17 +273,17 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.all_reduce(
predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
)

# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
torch.distributed.all_reduce(
sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
)

# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
@@ -304,7 +295,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):

@staticmethod
def backward(ctx, grad_output):

# Retrieve tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors

@@ -316,7 +306,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):

# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
@@ -326,8 +316,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):

class VocabUtility:
"""Split the vocabulary into `world_size` chunks amd return the
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""
first and last index of the vocabulary belonging to the `rank`
partition: Note that indices in [fist, last)"""

@staticmethod
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
@@ -393,11 +383,11 @@ class HiddenParallelEmbedding(torch.nn.Module):

# Word embeddings (parallel).
self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
self._word_embeddings_key = 'word_embeddings'
self._word_embeddings_key = "word_embeddings"

# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
self._position_embeddings_key = 'position_embeddings'
self._position_embeddings_key = "position_embeddings"
# Initialize the position embeddings.
# self.init_method(self.position_embeddings.weight)

@@ -405,7 +395,7 @@ class HiddenParallelEmbedding(torch.nn.Module):
# Add this as an optional field that can be added through
# method call so we can load a pretrain model without
# token types and add them as needed.
self._tokentype_embeddings_key = 'tokentype_embeddings'
self._tokentype_embeddings_key = "tokentype_embeddings"
if self.num_tokentypes > 0:
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
@@ -432,9 +422,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
This allows us to load the model normally and then add this embedding.
"""
if self.tokentype_embeddings is not None:
raise Exception('tokentype embeddings is already initialized')
raise Exception("tokentype embeddings is already initialized")
if torch.distributed.get_rank() == 0:
print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
self.num_tokentypes = num_tokentypes
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
# Initialize the token-type embeddings.
@@ -460,19 +450,16 @@ class HiddenParallelEmbedding(torch.nn.Module):
embeddings = self.embedding_dropout(embeddings)
return embeddings

def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
"""For easy load."""

state_dict_ = {}
state_dict_[self._word_embeddings_key] \
= self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] \
= self.position_embeddings.state_dict(
destination, prefix, keep_vars)
state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
if self.num_tokentypes > 0:
state_dict_[self._tokentype_embeddings_key] \
= self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars)
state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
destination, prefix, keep_vars
)

return state_dict_

@@ -486,9 +473,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
if "word_embeddings" in key:
state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
@@ -498,9 +484,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
if 'position_embeddings' in key:
state_dict_[key.split('position_embeddings.')[1]] \
= state_dict[key]
if "position_embeddings" in key:
state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
self.position_embeddings.load_state_dict(state_dict_, strict=strict)

# Tokentype embedding.
@@ -511,15 +496,14 @@ class HiddenParallelEmbedding(torch.nn.Module):
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it',
flush=True)
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)


class HiddenParallelEmbedding1D(torch.nn.Module):
@@ -542,21 +526,21 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
# Set the details for compatibility.
self.padding_idx = padding_idx
self.max_norm = None
self.norm_type = 2.
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None

# Allocate weights and initialize.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
init.uniform_(self.weight, -1, 1)

def forward(self, input_):

# Get the embeddings.
output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
self.scale_grad_by_freq, self.sparse)
output_parallel = F.embedding(
input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
)

# Reduce across all the model parallel GPUs.
output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
@@ -584,11 +568,9 @@ class HiddenParallelGPTLMHead1D(ParallelLayer):
# self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
# (hidden_size/q, vocab_size)
self.synced_embed = False
self.head = Linear1D_Row(in_features=embed_dim,
out_features=vocab_size,
bias=False,
dtype=dtype,
parallel_input=False)
self.head = Linear1D_Row(
in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False
)

def forward(self, x: Tensor) -> Tensor:
if self.synced_embed:

@@ -18,18 +18,21 @@ from colossalai.legacy.utils.activation_checkpoint import checkpoint
from colossalai.utils import checkpoint

__all__ = [
'GPTMLP1D', 'GPTSelfAttention1D', 'GPTTransformerLayer1D', 'FusedGPTSelfAttention1D', 'FusedGPTTransformerLayer1D'
"GPTMLP1D",
"GPTSelfAttention1D",
"GPTTransformerLayer1D",
"FusedGPTSelfAttention1D",
"FusedGPTTransformerLayer1D",
]


class GPTMLP1D(ParallelLayer):

def __init__(
self,
in_features: int,
mlp_ratio: int,
act_func: str = 'gelu',
dropout_prob: float = 0.,
act_func: str = "gelu",
dropout_prob: float = 0.0,
dtype=None,
checkpoint: bool = False,
skip_bias_add: bool = False,
@@ -82,7 +85,6 @@ class GPTMLP1D(ParallelLayer):


class GenericGPTSelfAttention1D(ParallelLayer):

def __init__(
self,
hidden_size: int,
@@ -118,8 +120,10 @@ class GenericGPTSelfAttention1D(ParallelLayer):

def _forward(self, hidden_states: Tensor, attention_mask=None) -> Tensor:
query_key_value = self.query_key_value(hidden_states)
new_qkv_shape = query_key_value.shape[:-1] + \
(self.num_attention_heads_per_partition, 3 * self.attention_head_size)
new_qkv_shape = query_key_value.shape[:-1] + (
self.num_attention_heads_per_partition,
3 * self.attention_head_size,
)
query_key_value = query_key_value.view(new_qkv_shape)
query_key_value = query_key_value.permute((0, 2, 1, 3))
query_layer, key_layer, value_layer = torch.chunk(query_key_value, 3, dim=-1)
@@ -152,28 +156,32 @@ class GenericGPTSelfAttention1D(ParallelLayer):


class GPTSelfAttention1D(GenericGPTSelfAttention1D):

def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024,
):
super().__init__(
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
)
self.softmax = nn.Softmax(dim=-1)
max_positions = max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))

@@ -181,7 +189,7 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# causal mask
query_length, key_length = query_layer.size(-2), key_layer.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
attention_scores = torch.where(causal_mask, attention_scores, self.masked_bias.to(attention_scores))
if attention_mask is not None:
# Apply the attention mask
@@ -191,50 +199,56 @@ class GPTSelfAttention1D(GenericGPTSelfAttention1D):


class FusedGPTSelfAttention1D(GenericGPTSelfAttention1D):

def __init__(self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024):
super().__init__(hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings)
self.softmax = kernel.FusedScaleMaskSoftmax(input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size))
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
attention_dropout_prob: float,
hidden_dropout_prob: float,
dtype=None,
checkpoint: bool = False,
max_position_embeddings=1024,
):
super().__init__(
hidden_size,
num_attention_heads,
attention_dropout_prob,
hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
)
self.softmax = kernel.FusedScaleMaskSoftmax(
input_in_fp16=True,
input_in_bf16=False,
attn_mask_type=AttnMaskType.causal,
scaled_masked_softmax_fusion=True,
mask_func=None,
softmax_in_fp32=True,
scale=math.sqrt(self.attention_head_size),
)

def softmax_forward(self, attention_scores, attention_mask, query_layer, key_layer):
return self.softmax(attention_scores, attention_mask)


class GenericGPTTransformerLayer1D(ParallelLayer):

def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.,
hidden_dropout_prob: float = 0.,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
attention=None,
layer_norm=None):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4.0,
attention_dropout_prob: float = 0.0,
hidden_dropout_prob: float = 0.0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
attention=None,
layer_norm=None,
):
super().__init__()
self.checkpoint = checkpoint
self.dtype = dtype
@@ -288,62 +302,68 @@ class GenericGPTTransformerLayer1D(ParallelLayer):


class GPTTransformerLayer1D(GenericGPTTransformerLayer1D):

def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False,
):
attention = GPTSelfAttention1D
layer_norm = nn.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)
super().__init__(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm,
)


class FusedGPTTransformerLayer1D(GenericGPTTransformerLayer1D):

def __init__(self,
hidden_size: int,
num_attention_heads: int,
act_func: str = 'gelu',
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False):
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
act_func: str = "gelu",
mlp_ratio: float = 4,
attention_dropout_prob: float = 0,
hidden_dropout_prob: float = 0,
dtype=None,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 0.00001,
apply_post_layer_norm: bool = False,
):
attention = FusedGPTSelfAttention1D
layer_norm = kernel.LayerNorm
super().__init__(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm)
super().__init__(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attention_dropout_prob,
hidden_dropout_prob=hidden_dropout_prob,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
attention=attention,
layer_norm=layer_norm,
)

@@ -17,17 +17,16 @@ from .embed import HiddenParallelEmbedding, HiddenParallelGPTLMHead1D, VocabPara
from .gpt1d import FusedGPTTransformerLayer1D, GPTTransformerLayer1D

__all__ = [
'GPT2_small_pipeline_1D',
'GPT2_exlarge_pipeline_1D',
'GPT3_pipeline_1D',
'GPT2_exlarge_pipeline_hybrid',
'GPT2_small_pipeline_hybrid',
'GPT3_pipeline_hybrid',
"GPT2_small_pipeline_1D",
"GPT2_exlarge_pipeline_1D",
"GPT3_pipeline_1D",
"GPT2_exlarge_pipeline_hybrid",
"GPT2_small_pipeline_hybrid",
"GPT3_pipeline_hybrid",
]


class GenericPipelineGPT(nn.Module):

def __init__(self, embedding=None, blocks=None, norm=None, head=None) -> None:
super().__init__()
self.embedding = embedding
@@ -44,7 +43,7 @@ class GenericPipelineGPT(nn.Module):
batch_size = hidden_states.shape[0]
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
@@ -54,25 +53,26 @@ class GenericPipelineGPT(nn.Module):


class PipelineGPT1D(GenericPipelineGPT):

def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
@@ -83,19 +83,24 @@ class PipelineGPT1D(GenericPipelineGPT):
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList([
GPTTransformerLayer1D(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
])
blocks = nn.ModuleList(
[
GPTTransformerLayer1D(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
)
for _ in range(num_layers)
]
)
if last:
norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
@@ -103,25 +108,26 @@ class PipelineGPT1D(GenericPipelineGPT):


class FusedPipelineGPT1D(GenericPipelineGPT):

def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4.0,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
@@ -132,19 +138,24 @@ class FusedPipelineGPT1D(GenericPipelineGPT):
head_cls = HiddenParallelGPTLMHead1D
if first:
embedding = embed_cls(hidden_size, vocab_size, max_position_embeddings, embed_drop_rate, dtype=dtype)
blocks = nn.ModuleList([
FusedGPTTransformerLayer1D(hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm) for _ in range(num_layers)
])
blocks = nn.ModuleList(
[
FusedGPTTransformerLayer1D(
hidden_size,
num_attention_heads,
act_func=act_func,
mlp_ratio=mlp_ratio,
attention_dropout_prob=attn_drop_rate,
hidden_dropout_prob=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
max_position_embeddings=max_position_embeddings,
layer_norm_epsilon=layer_norm_epsilon,
apply_post_layer_norm=apply_post_layer_norm,
)
for _ in range(num_layers)
]
)
if last:
norm = kernel.LayerNorm(hidden_size, eps=layer_norm_epsilon)
head = head_cls(vocab_size=vocab_size, embed_dim=hidden_size, dtype=dtype)
@@ -153,7 +164,7 @@ class FusedPipelineGPT1D(GenericPipelineGPT):
def forward(self, hidden_states=None, input_ids=None, attention_mask=None):
if self.embedding is not None:
hidden_states = self.embedding(input_ids=input_ids)
attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
attention_mask = attention_mask.to(dtype=hidden_states.dtype)  # fp16 compatibility
for block in self.blocks:
hidden_states, attention_mask = block(hidden_states, attention_mask)
if self.norm is not None:
@@ -162,44 +173,48 @@ class FusedPipelineGPT1D(GenericPipelineGPT):


class PipelineGPTHybrid(GenericPipelineGPT):

def __init__(self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.,
act_func: str = 'gelu',
mlp_ratio: int = 4,
attn_drop_rate: float = 0.,
drop_rate: float = 0.,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False):
def __init__(
self,
num_layers: int = 12,
hidden_size: int = 768,
num_attention_heads: int = 12,
vocab_size: int = 50304,
embed_drop_rate: float = 0.0,
act_func: str = "gelu",
mlp_ratio: int = 4,
attn_drop_rate: float = 0.0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
checkpoint: bool = False,
max_position_embeddings: int = 1024,
layer_norm_epsilon: float = 1e-5,
apply_post_layer_norm: bool = False,
first: bool = False,
last: bool = False,
embed_split_hidden=False,
):
embedding = None
norm = None
head = None
if first:
embedding = col_gpt.GPTEmbedding(hidden_size,
vocab_size,
max_position_embeddings,
dropout=embed_drop_rate,
dtype=dtype)
blocks = nn.ModuleList([
col_gpt.GPTBlock(hidden_size,
num_attention_heads,
mlp_ratio=mlp_ratio,
attention_dropout=attn_drop_rate,
dropout=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
activation=nn.functional.gelu) for _ in range(num_layers)
])
embedding = col_gpt.GPTEmbedding(
hidden_size, vocab_size, max_position_embeddings, dropout=embed_drop_rate, dtype=dtype
)
blocks = nn.ModuleList(
[
col_gpt.GPTBlock(
hidden_size,
num_attention_heads,
mlp_ratio=mlp_ratio,
attention_dropout=attn_drop_rate,
dropout=drop_rate,
dtype=dtype,
checkpoint=checkpoint,
activation=nn.functional.gelu,
)
for _ in range(num_layers)
]
)
if last:
norm = col_nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
# head = col_gpt.GPTLMHead(vocab_size=vocab_size,
@@ -215,7 +230,7 @@ def _filter_kwargs(func, kwargs):
return {k: v for k, v in kwargs.items() if k in sig.parameters}


def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
logger = get_dist_logger()

if gpc.is_initialized(ParallelMode.PIPELINE):
@@ -233,10 +248,10 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to
parts = partition_uniform(num_layers, pipeline_size, num_chunks)[pipeline_rank]
models = []
for start, end in parts:
kwargs['num_layers'] = end - start
kwargs['first'] = start == 0
kwargs['last'] = end == num_layers
logger.info(f'Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
kwargs["num_layers"] = end - start
kwargs["first"] = start == 0
kwargs["last"] = end == num_layers
logger.info(f"Rank{rank} build layer {start}-{end}, {end-start}/{num_layers} layers")
chunk = module_cls(**_filter_kwargs(module_cls.__init__, kwargs)).to(device)

if wrapper is not None:
@@ -253,70 +268,82 @@ def _build_generic_gpt_pipeline_1d(module_cls, num_layers, num_chunks, device=to
numel = 0
for _, param in model.named_parameters(recurse=True):
numel += param.numel()
logger.info(f'Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB')
logger.info(f"Rank{rank}/{pipeline_rank} model size = {numel * 2 / 1e9} GB")
return model


def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device('cuda'), fused=False, **kwargs):
def _build_gpt_pipeline_1d(num_layers, num_chunks, device=torch.device("cuda"), fused=False, **kwargs):
model = FusedPipelineGPT1D if fused else PipelineGPT1D
return _build_generic_gpt_pipeline_1d(model, num_layers, num_chunks, device, **kwargs)


def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
def _build_gpt_pipeline_hybrid(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
return _build_generic_gpt_pipeline_1d(PipelineGPTHybrid, num_layers, num_chunks, device, **kwargs)


def GPT2_small_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(12, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(48, num_chunks, fused=fused, **cfg)


def GPT3_pipeline_1D(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False, fused=False):
cfg = dict(hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_1d(96, num_chunks, fused=fused, **cfg)


def GPT2_exlarge_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=1600,
num_attention_heads=32,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(48, num_chunks, **cfg)


def GPT2_small_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=768,
num_attention_heads=12,
checkpoint=checkpoint,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(12, num_chunks, **cfg)


def GPT3_pipeline_hybrid(num_chunks=1, checkpoint=False, dtype=torch.float, embed_split_hidden=False):
cfg = dict(hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden)
cfg = dict(
hidden_size=12288,
num_attention_heads=96,
checkpoint=checkpoint,
max_position_embeddings=2048,
dtype=dtype,
embed_split_hidden=embed_split_hidden,
)
return _build_gpt_pipeline_hybrid(96, num_chunks, **cfg)