[shardformer] add Dropout layer supporting different dropout patterns (#3856)

* add dropout layer, add dropout test

* modify seed manager to be a context manager

* add a copy of col_nn.layer

* add dist_crossentropy loss; separate module test

* polish the code

* fix dist crossentropy loss
Author: FoolPlayer
Date: 2023-06-01 16:21:02 +08:00
Committed by: Frank Lee
Parent: c594dc2f1c
Commit: ab8a47f830
14 changed files with 1413 additions and 41 deletions
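The first two items in the message — a Dropout layer that supports different dropout patterns, and a seed manager reworked as a context manager — belong together: under tensor parallelism some dropout calls must draw the same mask on every rank (replicated activations) while others must draw a different mask per rank (partitioned activations), and temporarily switching the RNG seed inside a context manager is a natural way to select between the two. The sketch below illustrates only that pattern; the names (SeedManager, same_across_ranks) are assumptions for illustration, not the actual ColossalAI API.

# Minimal sketch of the seed-manager-as-context-manager idea; names are illustrative.
from contextlib import contextmanager

import torch
import torch.distributed as dist


class SeedManager:
    """Switches the RNG seed so dropout masks can be identical or distinct
    across tensor-parallel ranks (saves/restores CPU RNG state only, for brevity)."""

    def __init__(self, base_seed: int = 1024):
        self.base_seed = base_seed

    @contextmanager
    def seed(self, same_across_ranks: bool):
        rank = dist.get_rank() if dist.is_initialized() else 0
        # Same seed on every rank -> identical dropout mask (replicated activations);
        # rank-offset seed -> different mask per rank (partitioned activations).
        seed = self.base_seed if same_across_ranks else self.base_seed + rank
        state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            torch.random.set_rng_state(state)


# Usage: choose the dropout pattern per call site.
# manager = SeedManager()
# with manager.seed(same_across_ranks=True):
#     x = torch.nn.functional.dropout(x, p=0.1)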


@@ -73,7 +73,6 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
         total_input = input
         grad_input = grad_output.matmul(weight)
-        grad_output = grad_output.contiguous()
         # Convert the tensor shapes to 2D for execution compatibility
         grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1], grad_output.shape[2])
         total_input = total_input.view(total_input.shape[0] * total_input.shape[1], total_input.shape[2])
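The hunk above flattens the leading (batch, seq) dimensions into one, typically so that the weight gradient can then be computed as a single 2D GEMM. A small self-contained shape check of that equivalence in plain PyTorch (not the library code):

# For y = x @ W^T: grad_input stays 3D, while grad_W = grad_y^T @ x becomes one
# matmul after folding (batch, seq) into a single dimension.
import torch

batch, seq, in_features, out_features = 2, 4, 8, 6
x = torch.randn(batch, seq, in_features)
weight = torch.randn(out_features, in_features)
grad_output = torch.randn(batch, seq, out_features)

grad_input = grad_output.matmul(weight)                 # (batch, seq, in_features)

grad_output_2d = grad_output.reshape(batch * seq, out_features)
x_2d = x.reshape(batch * seq, in_features)
grad_weight = grad_output_2d.t().matmul(x_2d)           # (out_features, in_features)

assert grad_input.shape == x.shape
assert grad_weight.shape == weight.shape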


@@ -469,8 +469,7 @@ class Linear1D_Col(ParallelLayer):
         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')
-        # self.out_features_per_partition = divide(out_features*2, gpc.tensor_parallel_size)
-        self.out_features_per_partition = out_features
+        self.out_features_per_partition = divide(out_features, gpc.tensor_parallel_size)
         # Parameters.
         # Initialize weight.
@@ -613,8 +612,7 @@ class Linear1D_Row(ParallelLayer):
             raise ValueError('cannot skip bias addition if bias is None')
         # Divide the weight matrix along the last dimension.
-        # self.input_size_per_partition = divide(in_features*2, gpc.tensor_parallel_size)
-        self.input_size_per_partition = in_features
+        self.input_size_per_partition = divide(in_features, gpc.tensor_parallel_size)
         # Parameters.
         # Initialize weight.
@@ -886,8 +884,7 @@ class VocabParallelEmbedding1D(ParallelLayer):
         tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
         tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-        # self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings_per_partition = num_embeddings
+        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
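All three hunks restore the same pattern: the per-partition size is the full size divided evenly across the tensor-parallel group, rather than the unsplit value. A small sketch of that arithmetic, with a divide helper mirroring the usual divisibility check (the sizes are illustrative, not taken from the commit):

# Each rank owns size // world_size of the partitioned dimension; divide() fails
# loudly if the dimension does not split evenly.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, \
        f"{numerator} is not divisible by {denominator}"
    return numerator // denominator


tensor_parallel_size = 4

# Column-parallel linear: each rank holds a slice of the output features.
out_features = 4096
out_features_per_partition = divide(out_features, tensor_parallel_size)       # 1024

# Vocab-parallel embedding: each rank holds a contiguous slice of the vocabulary.
num_embeddings = 50304   # example vocab size, chosen divisible by the parallel size
num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)   # 12576

tensor_parallel_rank = 2
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition       # 25152
vocab_end_index = vocab_start_index + num_embeddings_per_partition            # 37728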