[tensor] refactor colo-tensor (#992)

* refactor colo-tensor and update linear op

* polish code

* polish code

* update ops and unit tests

* update unit tests

* polish code

* rename dist_spec module

* polish code

* polish code

* remove unneeded import

* fix pipelinable
This commit is contained in:
ver217
2022-05-19 12:44:59 +08:00
committed by GitHub
parent 1467d83edf
commit ad536e308e
27 changed files with 657 additions and 616 deletions

View File

@@ -1,31 +1,52 @@
import torch
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, dist_spec
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
from ._utils import GeneralTensor, convert_to_colo_tensor
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Dcol(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
output_parallel = F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = TensorSpec(
dist_spec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group().size()]),
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
[ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
output = ColoTensor.init_from_torch_tensor(output_parallel, spec=output_spec)
output.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output
def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
def colo_embedding_1Drow(input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
num_embeddings_per_partition = weight.size(0)
@@ -33,53 +54,87 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
vocab_end_index = vocab_start_index + num_embeddings_per_partition
# Build the mask.
input_mask = (input_tensor.torch_tensor() < vocab_start_index) | \
(input_tensor.torch_tensor() >= vocab_end_index)
input_mask = (input_tensor < vocab_start_index) | \
(input_tensor >= vocab_end_index)
# Mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.torch_tensor().clone() - vocab_start_index
masked_input = input_tensor.clone() - vocab_start_index
masked_input[input_mask] = 0
partial_output = torch.nn.functional.embedding(masked_input, weight.torch_tensor(), *args, **kwargs)
partial_output = F.embedding(masked_input,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
# Mask the output embedding.
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, parallel_action.parallel_mode)
output = ColoTensor.init_from_torch_tensor(output,
spec=TensorSpec(dist_spec.replicate(weight.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
return output
@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
def colo_embedding_1d(mode: str,
input_tensor: ColoTensor,
weight: ColoTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
assert mode in ('row', 'col')
funcs = {'row': colo_embedding_1Drow, 'col': colo_embedding_1Dcol}
return funcs[mode](input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
@colo_op_impl(F.embedding)
def colo_embedding(input_tensor: GeneralTensor,
weight: GeneralTensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False):
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method looks up an embedding table.
"""
input_tensor = args[0]
weight = args[1]
args = args[2:]
if not isinstance(input_tensor, ColoTensor):
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
if not isinstance(weight, ColoTensor):
weight = ColoTensor.init_from_torch_tensor(weight)
input_tensor, weight = tuple(map(convert_to_colo_tensor, (input_tensor, weight)))
# Handle differen parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
return ColoTensor.init_from_torch_tensor(output)
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_row():
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
mode = 'row'
elif weight.spec.is_1D_col():
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
mode = 'col'
else:
raise NotImplementedError
return colo_embedding_1d(mode,
input_tensor,
weight,
padding_idx=padding_idx,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
else:
raise NotImplementedError