[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26

1268 changed files with 50037 additions and 38444 deletions


@@ -1,7 +1,6 @@
 import torch
 import torch.nn.init as init
 from torch import Tensor
-from torch import distributed as dist
 from torch import nn as nn
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter
@@ -12,7 +11,7 @@ from colossalai.legacy.nn.layer.base_layer import ParallelLayer
 from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_backward, reduce_grad, reduce_input
 from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
 from colossalai.legacy.nn.layer.utils import divide
-from colossalai.legacy.registry import LAYERS, LOSSES, MODELS
+from colossalai.legacy.registry import LAYERS, LOSSES
 from colossalai.utils import get_current_device
@@ -30,13 +29,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         will ignore this embedding
     """

-    def __init__(self,
-                 hidden_size,
-                 vocab_size,
-                 max_sequence_length,
-                 embedding_dropout_prob,
-                 num_tokentypes=0,
-                 dtype=torch.float):
+    def __init__(
+        self, hidden_size, vocab_size, max_sequence_length, embedding_dropout_prob, num_tokentypes=0, dtype=torch.float
+    ):
         super(VocabParallelEmbedding, self).__init__()
         self.hidden_size = hidden_size
@@ -44,11 +39,11 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Word embeddings (parallel).
         self.word_embeddings = VocabParallelEmbedding1D(vocab_size, self.hidden_size, dtype=dtype)
-        self._word_embeddings_key = 'word_embeddings'
+        self._word_embeddings_key = "word_embeddings"

         # Position embedding (serial).
         self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size, dtype=dtype)
-        self._position_embeddings_key = 'position_embeddings'
+        self._position_embeddings_key = "position_embeddings"
         # Initialize the position embeddings.
         # self.init_method(self.position_embeddings.weight)
@@ -56,7 +51,7 @@ class VocabParallelEmbedding(torch.nn.Module):
         # Add this as an optional field that can be added through
         # method call so we can load a pretrain model without
         # token types and add them as needed.
-        self._tokentype_embeddings_key = 'tokentype_embeddings'
+        self._tokentype_embeddings_key = "tokentype_embeddings"
         if self.num_tokentypes > 0:
             self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size, dtype=dtype)
             # Initialize the token-type embeddings.
@@ -83,9 +78,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         This allows us to load the model normally and then add this embedding.
         """
         if self.tokentype_embeddings is not None:
-            raise Exception('tokentype embeddings is already initialized')
+            raise Exception("tokentype embeddings is already initialized")
         if torch.distributed.get_rank() == 0:
-            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+            print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
         self.num_tokentypes = num_tokentypes
         self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
         # Initialize the token-type embeddings.
@@ -112,19 +107,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             embeddings = self.embedding_dropout(embeddings)
         return embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
         """For easy load."""
         state_dict_ = {}
-        state_dict_[self._word_embeddings_key] \
-            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
-        state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
         if self.num_tokentypes > 0:
-            state_dict_[self._tokentype_embeddings_key] \
-                = self.tokentype_embeddings.state_dict(
-                    destination, prefix, keep_vars)
+            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
+                destination, prefix, keep_vars
+            )

         return state_dict_
@@ -138,9 +130,8 @@ class VocabParallelEmbedding(torch.nn.Module):
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'word_embeddings' in key:
-                    state_dict_[key.split('word_embeddings.')[1]] \
-                        = state_dict[key]
+                if "word_embeddings" in key:
+                    state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)

         # Position embedding.
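
The load path in this hunk (and the analogous position/tokentype hunks below) is a prefix-stripping pattern for backward compatibility: old flat checkpoints keyed like "word_embeddings.weight" are rewritten so the nested sub-module can load them. A minimal sketch with a toy dict (values are illustrative, not from this file):

flat = {"word_embeddings.weight": 0, "position_embeddings.weight": 1}
sub = {k.split("word_embeddings.")[1]: v for k, v in flat.items() if "word_embeddings" in k}
assert sub == {"weight": 0}  # now loadable by self.word_embeddings.load_state_dict
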
@@ -150,9 +141,8 @@ class VocabParallelEmbedding(torch.nn.Module):
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'position_embeddings' in key:
-                    state_dict_[key.split('position_embeddings.')[1]] \
-                        = state_dict[key]
+                if "position_embeddings" in key:
+                    state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
         self.position_embeddings.load_state_dict(state_dict_, strict=strict)

         # Tokentype embedding.
@@ -163,15 +153,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             else:
                 # for backward compatibility.
                 for key in state_dict.keys():
-                    if 'tokentype_embeddings' in key:
-                        state_dict_[key.split('tokentype_embeddings.')[1]] \
-                            = state_dict[key]
+                    if "tokentype_embeddings" in key:
+                        state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
             if len(state_dict_.keys()) > 0:
                 self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
             else:
-                print('***WARNING*** expected tokentype embeddings in the '
-                      'checkpoint but could not find it',
-                      flush=True)
+                print(
+                    "***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
+                )


 class VocabParallelEmbedding1D(torch.nn.Module):
@@ -193,37 +182,41 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         # Set the details for compatibility.
         self.padding_idx = None
         self.max_norm = None
-        self.norm_type = 2.
+        self.norm_type = 2.0
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None
         self.tensor_model_parallel_size = gpc.tensor_parallel_size
         # Divide the weight matrix along the vocabulary dimension.
-        self.vocab_start_index, self.vocab_end_index = \
-            VocabUtility.vocab_range_from_global_vocab_size(
-                self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D),
-                self.tensor_model_parallel_size)
-        self.num_embeddings_per_partition = self.vocab_end_index - \
-            self.vocab_start_index
+        self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size(
+            self.num_embeddings, gpc.get_local_rank(ParallelMode.PARALLEL_1D), self.tensor_model_parallel_size
+        )
+        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

         # Allocate weights and initialize.
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)

     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
-            input_mask = (input_ < self.vocab_start_index) | \
-                (input_ >= self.vocab_end_index)
+            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
             # Mask the input.
             masked_input = input_.clone() - self.vocab_start_index
             masked_input[input_mask] = 0
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, self.max_norm, self.norm_type,
-                                      self.scale_grad_by_freq, self.sparse)
+        output_parallel = F.embedding(
+            masked_input,
+            self.weight,
+            self.padding_idx,
+            self.max_norm,
+            self.norm_type,
+            self.scale_grad_by_freq,
+            self.sparse,
+        )
         # Mask the output embedding.
         if self.tensor_model_parallel_size > 1:
             output_parallel[input_mask, :] = 0.0
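
Aside from the reformatting, this hunk contains the core of vocab parallelism: each rank owns a contiguous vocabulary slice, looks up only its own rows, and zeroes everything else so a later reduction can sum the partial results. A single-process sketch (plain PyTorch; local_embed and the toy shapes are illustrative, not part of this file):

import torch
import torch.nn.functional as F

def local_embed(input_ids, weight, vocab_start, vocab_end):
    # Tokens outside this rank's [vocab_start, vocab_end) slice are masked.
    mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = input_ids.clone() - vocab_start
    local_ids[mask] = 0  # any in-range row works; its output is zeroed below
    out = F.embedding(local_ids, weight)
    out[mask, :] = 0.0  # contribute nothing for out-of-range tokens
    return out  # the real layer then sums these partials across ranks

# Two "ranks" each own half of a 10-token vocabulary.
w0, w1 = torch.randn(5, 4), torch.randn(5, 4)
ids = torch.tensor([1, 7, 3])
full = local_embed(ids, w0, 0, 5) + local_embed(ids, w1, 5, 10)  # shape (3, 4)
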
@@ -234,7 +227,6 @@ class VocabParallelEmbedding1D(torch.nn.Module):
 @LOSSES.register_module
 class vocab_parallel_cross_entropy(nn.Module):
-
     def __init__(self):
         super().__init__()
@@ -242,20 +234,19 @@ class vocab_parallel_cross_entropy(nn.Module):
         """Helper function for the cross entropy."""
         vocab_parallel_logits = vocab_parallel_logits[..., :-1, :].contiguous()
         target = target[..., 1:].contiguous()
-        return _VocabParallelCrossEntropy.apply(vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)),
-                                                target.view(-1))
+        return _VocabParallelCrossEntropy.apply(
+            vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1)), target.view(-1)
+        )


 class _VocabParallelCrossEntropy(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, vocab_parallel_logits, target):
         # Maximum value along vocab dimension across all GPUs.
         logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
-        torch.distributed.all_reduce(logits_max,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(
+            logits_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.PARALLEL_1D)
+        )
         # Subtract the maximum value.
         vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
@@ -282,17 +273,17 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         predicted_logits = predicted_logits_1d.view_as(target)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
-        torch.distributed.all_reduce(predicted_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(
+            predicted_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
+        )

         # Sum of exponential of logits along vocab dimension across all GPUs.
         exp_logits = vocab_parallel_logits
         torch.exp(vocab_parallel_logits, out=exp_logits)
         sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.all_reduce(
+            sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.PARALLEL_1D)
+        )

         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
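
The formula in this hunk (untouched by the formatting pass) is the max-subtracted log-sum-exp form of cross entropy; each all-reduce exists because every rank holds only a slice of the vocabulary. A single-process numeric check (toy logits, no distributed group):

import torch

z = torch.tensor([[2.0, 5.0, 3.0]])  # logits for one token
t = torch.tensor([1])                # target index
z_max = z.max(dim=-1, keepdim=True)[0]
shifted = z - z_max                  # subtract the max for numerical stability
loss = torch.log(shifted.exp().sum(dim=-1)) - shifted[0, t]
assert torch.allclose(loss, torch.nn.functional.cross_entropy(z, t))
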
@@ -304,7 +295,6 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-
         # Retrieve tensors from the forward path.
         softmax, target_mask, masked_target_1d = ctx.saved_tensors
@@ -316,7 +306,7 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Add the gradient from matching classes.
         arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
-        grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
+        grad_2d[arange_1d, masked_target_1d] -= 1.0 - target_mask.view(-1).float()

         # Finally elementwise multiplication with the output gradients.
         grad_input.mul_(grad_output.unsqueeze(dim=-1))
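
The line changed here implements the classic cross-entropy gradient, softmax(z) minus a one-hot at the target; the target_mask factor applies the subtraction only on the rank that actually owns the target index. A single-process sketch checking the identity against autograd (toy values, illustrative only):

import torch

z = torch.tensor([[2.0, 5.0, 3.0]], requires_grad=True)
t = torch.tensor([1])
torch.nn.functional.cross_entropy(z, t).backward()
manual = torch.softmax(z.detach(), dim=-1)
manual[0, t] -= 1.0  # subtract the one-hot at the target
assert torch.allclose(z.grad, manual)
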
@@ -326,8 +316,8 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
 class VocabUtility:
     """Split the vocabulary into `world_size` chunks and return the
-        first and last index of the vocabulary belonging to the `rank`
-        partition: Note that indices in [first, last)"""
+    first and last index of the vocabulary belonging to the `rank`
+    partition: Note that indices in [first, last)"""

     @staticmethod
     def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank, world_size):
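
For reference, the range arithmetic behind VocabUtility: split the vocabulary evenly and hand each rank a half-open [first, last) slice. A collapsed sketch of the global and per-partition helpers (hypothetical name; assumes the vocabulary size divides evenly by world_size):

def vocab_range(global_vocab_size, rank, world_size):
    per_partition = global_vocab_size // world_size
    first = rank * per_partition
    return first, first + per_partition  # indices in [first, last)

assert vocab_range(50304, 3, 4) == (37728, 50304)
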
@@ -393,11 +383,11 @@ class HiddenParallelEmbedding(torch.nn.Module):
         # Word embeddings (parallel).
         self.word_embeddings = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
-        self._word_embeddings_key = 'word_embeddings'
+        self._word_embeddings_key = "word_embeddings"

         # Position embedding (serial).
         self.position_embeddings = torch.nn.Embedding(max_sequence_length, self.hidden_size)
-        self._position_embeddings_key = 'position_embeddings'
+        self._position_embeddings_key = "position_embeddings"
         # Initialize the position embeddings.
         # self.init_method(self.position_embeddings.weight)
@@ -405,7 +395,7 @@ class HiddenParallelEmbedding(torch.nn.Module):
         # Add this as an optional field that can be added through
         # method call so we can load a pretrain model without
         # token types and add them as needed.
-        self._tokentype_embeddings_key = 'tokentype_embeddings'
+        self._tokentype_embeddings_key = "tokentype_embeddings"
         if self.num_tokentypes > 0:
             self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, self.hidden_size)
             # Initialize the token-type embeddings.
@@ -432,9 +422,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
         This allows us to load the model normally and then add this embedding.
         """
         if self.tokentype_embeddings is not None:
-            raise Exception('tokentype embeddings is already initialized')
+            raise Exception("tokentype embeddings is already initialized")
         if torch.distributed.get_rank() == 0:
-            print('adding embedding for {} tokentypes'.format(num_tokentypes), flush=True)
+            print("adding embedding for {} tokentypes".format(num_tokentypes), flush=True)
         self.num_tokentypes = num_tokentypes
         self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, self.hidden_size)
         # Initialize the token-type embeddings.
@@ -460,19 +450,16 @@ class HiddenParallelEmbedding(torch.nn.Module):
             embeddings = self.embedding_dropout(embeddings)
         return embeddings

-    def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False):
+    def state_dict_for_save_checkpoint(self, destination=None, prefix="", keep_vars=False):
         """For easy load."""
         state_dict_ = {}
-        state_dict_[self._word_embeddings_key] \
-            = self.word_embeddings.state_dict(destination, prefix, keep_vars)
-        state_dict_[self._position_embeddings_key] \
-            = self.position_embeddings.state_dict(
-                destination, prefix, keep_vars)
+        state_dict_[self._word_embeddings_key] = self.word_embeddings.state_dict(destination, prefix, keep_vars)
+        state_dict_[self._position_embeddings_key] = self.position_embeddings.state_dict(destination, prefix, keep_vars)
         if self.num_tokentypes > 0:
-            state_dict_[self._tokentype_embeddings_key] \
-                = self.tokentype_embeddings.state_dict(
-                    destination, prefix, keep_vars)
+            state_dict_[self._tokentype_embeddings_key] = self.tokentype_embeddings.state_dict(
+                destination, prefix, keep_vars
+            )

         return state_dict_
@@ -486,9 +473,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'word_embeddings' in key:
-                    state_dict_[key.split('word_embeddings.')[1]] \
-                        = state_dict[key]
+                if "word_embeddings" in key:
+                    state_dict_[key.split("word_embeddings.")[1]] = state_dict[key]
         self.word_embeddings.load_state_dict(state_dict_, strict=strict)

         # Position embedding.
@@ -498,9 +484,8 @@ class HiddenParallelEmbedding(torch.nn.Module):
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'position_embeddings' in key:
-                    state_dict_[key.split('position_embeddings.')[1]] \
-                        = state_dict[key]
+                if "position_embeddings" in key:
+                    state_dict_[key.split("position_embeddings.")[1]] = state_dict[key]
         self.position_embeddings.load_state_dict(state_dict_, strict=strict)

         # Tokentype embedding.
# Tokentype embedding.
@@ -511,15 +496,14 @@ class HiddenParallelEmbedding(torch.nn.Module):
else:
# for backward compatibility.
for key in state_dict.keys():
if 'tokentype_embeddings' in key:
state_dict_[key.split('tokentype_embeddings.')[1]] \
= state_dict[key]
if "tokentype_embeddings" in key:
state_dict_[key.split("tokentype_embeddings.")[1]] = state_dict[key]
if len(state_dict_.keys()) > 0:
self.tokentype_embeddings.load_state_dict(state_dict_, strict=strict)
else:
print('***WARNING*** expected tokentype embeddings in the '
'checkpoint but could not find it',
flush=True)
print(
"***WARNING*** expected tokentype embeddings in the " "checkpoint but could not find it", flush=True
)
class HiddenParallelEmbedding1D(torch.nn.Module):
@@ -542,21 +526,21 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
         # Set the details for compatibility.
         self.padding_idx = padding_idx
         self.max_norm = None
-        self.norm_type = 2.
+        self.norm_type = 2.0
         self.scale_grad_by_freq = False
         self.sparse = False
         self._weight = None

         # Allocate weights and initialize.
-        factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
+        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)

     def forward(self, input_):
         # Get the embeddings.
-        output_parallel = F.embedding(input_, self.weight, self.padding_idx, self.max_norm, self.norm_type,
-                                      self.scale_grad_by_freq, self.sparse)
+        output_parallel = F.embedding(
+            input_, self.weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse
+        )

         # Reduce across all the model parallel GPUs.
         output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
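
Unlike the vocab-parallel layer above, which sums partial rows across ranks, the hidden-parallel variant gives every rank all vocabulary rows but only a slice of the embedding dimension, so the per-rank outputs concatenate along the last axis; that is what gather_forward_split_backward does here. A single-process sketch with toy shapes (illustrative only):

import torch
import torch.nn.functional as F

vocab, hidden, world_size = 10, 8, 2
shards = [torch.randn(vocab, hidden // world_size) for _ in range(world_size)]
ids = torch.tensor([1, 7, 3])
out = torch.cat([F.embedding(ids, w) for w in shards], dim=-1)  # shape (3, 8)
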
@@ -584,11 +568,9 @@ class HiddenParallelGPTLMHead1D(ParallelLayer):
         # self.embedding = HiddenParallelEmbedding1D(vocab_size, hidden_size, dtype, padding_idx)
         # (hidden_size/q, vocab_size)
         self.synced_embed = False
-        self.head = Linear1D_Row(in_features=embed_dim,
-                                 out_features=vocab_size,
-                                 bias=False,
-                                 dtype=dtype,
-                                 parallel_input=False)
+        self.head = Linear1D_Row(
+            in_features=embed_dim, out_features=vocab_size, bias=False, dtype=dtype, parallel_input=False
+        )

     def forward(self, x: Tensor) -> Tensor:
         if self.synced_embed: