[Tensor] add embedding tp1d row (#904)

This commit is contained in:
Ziyue Jiang
2022-04-29 14:10:05 +08:00
committed by GitHub
parent 16122d5fac
commit f593a5637e
5 changed files with 108 additions and 8 deletions

View File

@@ -9,7 +9,7 @@ from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table)
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
if not input_tensor.is_gathered():
@@ -25,6 +25,37 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
output.gather()
return output
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Embedding)
if not input_tensor.is_gathered():
input_tensor.gather()
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
num_embeddings_per_partition = weight.size(0)
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# Build the mask.
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
(input_tensor.torch_tensor() >= vocab_end_index)
# Mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
masked_input[input_mask] = 0
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(),
*args, **kwargs)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(output)
return output
@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
@@ -48,7 +79,9 @@ def colo_embedding(types, args, kwargs, pg):
return ColoTensor.init_from_torch_tensor(output)
elif weight.shard_spec.num_action == 1: # Single Model Parallel Applied
compute_patterns = weight.shard_spec.compute_patterns
if ComputePattern.TP1DCol_Embedding in compute_patterns:
if ComputePattern.TP1DRow_Embedding in compute_patterns:
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
elif ComputePattern.TP1DCol_Embedding in compute_patterns:
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
else:
raise NotImplementedError

View File

@@ -166,6 +166,7 @@ class ColoTensor(object):
dim = -1
self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
self._shard_pattern = ShardPattern.NA
self._size = self._torch_tensor.size()
def is_gathered(self) -> bool:
return self._shard_pattern == ShardPattern.NA