[tensor] derive compute pattern from dist spec (#971)

* derive compute pattern from dist spec

* polish code
Author: ver217
Date: 2022-05-16 14:58:08 +08:00 (committed by GitHub)
Commit: c2fdc6a011
Parent: 46bc95708f
10 changed files with 79 additions and 65 deletions

View File

@@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec
def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
- parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+ parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
@@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float],
alpha: Union[int, float]) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
- parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+ parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode)
@@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg):
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_spec(): # No Model Parallel Applied
- assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op'
+ assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
+ assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.init_from_torch_tensor(
- torch.addbmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
- elif mat2.spec.num_action == 1: # Single Model Parallel Applied
+ torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha))
+ elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group()))
mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec)
- compute_patterns = mat2.spec.compute_patterns
- if ComputePattern.TP1DRow in compute_patterns:
+ if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha)
- elif ComputePattern.TP1DCol in compute_patterns:
+ elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha)
else:
raise NotImplementedError
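The comment in colo_addmm_1Drow above, mat1:S[1] x mat2:S[0] = Output:P, is the standard inner-dimension split: each rank holds a column slice of mat1 and the matching row slice of mat2, so the local product is only a partial result and the all-reduce sums the partials before beta and alpha are applied. A minimal single-process sketch with two simulated shards (plain torch, not the ColoTensor code path):

import torch

torch.manual_seed(0)
inp = torch.randn(4, 6)      # the addmm "input" term
mat1 = torch.randn(4, 8)     # S[1]: split along dim 1 across ranks
mat2 = torch.randn(8, 6)     # S[0]: split along dim 0 across ranks
beta, alpha = 1.0, 1.0

ref = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)

mat1_shards = torch.chunk(mat1, 2, dim=1)                      # what each of 2 ranks would hold
mat2_shards = torch.chunk(mat2, 2, dim=0)
partials = [a @ b for a, b in zip(mat1_shards, mat2_shards)]   # Output:P on each rank
out = beta * inp + alpha * sum(partials)                       # All-Reduce(Output) is this sum

print(torch.allclose(out, ref, atol=1e-5))                     # True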

View File

@@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall
def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
- parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+ parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs)
@@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
- parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+ parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
@@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg):
weight = ColoTensor.init_from_torch_tensor(weight)
# Handle differen parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
+ assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
return ColoTensor.init_from_torch_tensor(output)
- elif weight.spec.num_action == 1: # Single Model Parallel Applied
- compute_patterns = weight.spec.compute_patterns
- if ComputePattern.TP1DRow in compute_patterns:
+ elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
+ if weight.spec.is_1D_row():
return colo_embedding_1Drow(input_tensor, weight, args, kwargs)
- elif ComputePattern.TP1DCol in compute_patterns:
+ elif weight.spec.is_1D_col():
return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
else:
raise NotImplementedError
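The comments in colo_embedding_1Drow above describe the row-sharded lookup table: each rank keeps num_embeddings/P rows, indices outside its range are masked so the local lookup contributes zeros, and the per-rank partial outputs are summed ("Reduce all"). A minimal single-process sketch of that masking with two simulated shards (plain torch; the real kernel offsets indices by tensor_parallel_rank as in the hunk above):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_embeddings, dim, world_size = 8, 5, 2
table = torch.randn(num_embeddings, dim)
idx = torch.randint(0, num_embeddings, (3, 4))

ref = F.embedding(idx, table)

per_rank = num_embeddings // world_size
partials = []
for rank in range(world_size):
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    shard = table[lo:hi]                                 # (num_embeddings/P, embedding_dim)
    mask = (idx < lo) | (idx >= hi)                      # indices not owned by this rank
    local_idx = (idx - lo).clamp(0, per_rank - 1)
    out = F.embedding(local_idx, shard)
    out[mask] = 0.0                                      # zero the rows looked up elsewhere
    partials.append(out)

print(torch.allclose(sum(partials), ref))                # True: "Reduce all" is this sum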

View File

@@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
- parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+ parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
@@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
- parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+ parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group()))
input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode)
- output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor())
+ if bias is not None:
+ bias = bias.torch_tensor()
+ output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias)
output = ColoTensor.init_from_torch_tensor(
output_parallel,
@@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
- assert not bias.has_spec(), 'Invalid bias spec for native Linear op'
+ assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
input_tensor = input_tensor.torch_tensor()
weight = weight.torch_tensor()
- bias = bias.torch_tensor()
+ if bias is not None:
+ bias = bias.torch_tensor()
ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
- elif weight.spec.num_action == 1: # Single Model Parallel Applied
- compute_patterns = weight.spec.compute_patterns
- if ComputePattern.TP1DRow in compute_patterns:
+ elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
+ if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
ret_tensor = colo_linear_1Drow(input_tensor, weight, bias)
- elif ComputePattern.TP1DCol in compute_patterns:
+ elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias)
else:
raise NotImplementedError
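For the column-parallel branch described above ("Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]", then "All-Gather(Output)"), every rank sees the full input but owns only a slice of the output features and the matching bias slice, so concatenating the per-rank outputs restores the full result. A minimal single-process sketch with two simulated shards (plain torch, not the ColoTensor dispatch):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 6)                       # Input:B, replicated on every rank
weight = torch.randn(10, 6)                 # torch layout: (out_features, in_features)
bias = torch.randn(10)

ref = F.linear(x, weight, bias)

# Two simulated ranks, each owning a slice of the output features and of the bias.
w_shards = torch.chunk(weight, 2, dim=0)
b_shards = torch.chunk(bias, 2, dim=0)
partials = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]   # each (4, 5)

out = torch.cat(partials, dim=-1)           # "All-Gather(Output)": concatenate the slices
print(torch.allclose(out, ref))             # True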

View File

@@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor import ColoTensor
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
@colo_op_impl(torch.nn.functional.cross_entropy)
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
arg_num = len(args)
@@ -27,13 +28,13 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
if isinstance(target, ColoTensor):
target = target.torch_tensor()
- if input_tensor.spec.is_gathered(): # Input is gathered
- return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
- input_tensor.torch_tensor(), target, weight))
+ if input_tensor.spec.is_gathered(): # Input is gathered
+ return ColoTensor.init_from_torch_tensor(
+ torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
- if input_tensor.spec.is_1Dcol():
- return ColoTensor.init_from_torch_tensor(
- VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
+ if input_tensor.spec.is_1D_col():
+ return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(),
+ target))
else:
raise NotImplementedError
else:
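The 1D-column branch exists because softmax cross-entropy needs a reduction over the whole class dimension, which a rank that only holds a shard of the logits cannot do locally; VocabParallelCrossEntropyLoss1D performs that exchange. A minimal single-process sketch of the underlying identity with two simulated vocab shards (plain torch):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)              # (batch, vocab)
target = torch.randint(0, 10, (4,))

ref = F.cross_entropy(logits, target)    # dense loss on the gathered logits

shards = torch.chunk(logits, 2, dim=-1)                              # each rank holds (batch, 5)
partial_lse = torch.stack([s.logsumexp(dim=-1) for s in shards], 0)  # per-rank partial reductions
global_lse = partial_lse.logsumexp(dim=0)                            # needs a cross-rank exchange
# (in the sharded case the target logit also has to be fetched from the rank that owns it)
target_logit = logits.gather(1, target.unsqueeze(1)).squeeze(1)
loss = (global_lse - target_logit).mean()

print(torch.allclose(loss, ref))         # True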

View File

@@ -1,6 +1,7 @@
from enum import Enum
from torch.distributed import ProcessGroup
from typing import Optional, List
+ from numpy import prod
__all__ = ['replicate', 'shard']
@@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
assert process_group is not None
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
+ assert prod(num_partitions) == process_group.size()
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
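The added assert ties the shard layout to the process group: the product of num_partitions must equal process_group.size(), so every rank owns exactly one piece. A tiny illustration of the check itself (plain numpy, with world_size standing in for process_group.size() since no ProcessGroup is initialized here):

from numpy import prod

world_size = 4                      # stands in for process_group.size()
for num_partitions in ([4], [2, 2], [2]):
    ok = prod(num_partitions) == world_size
    print(num_partitions, 'accepted' if ok else 'rejected by the new assert')
# [4] accepted, [2, 2] accepted, [2] rejected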

View File

@@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
class ComputePattern(Enum):
- # TODO (ver217): remove TP1DRow_<ops>
- TP1DRow = 0
- TP1DCol = 9
- TP1DRow_Linear = 1
- TP1DCol_Linear = 2
- TP1DRow_Embedding = 3
- TP1DCol_Embedding = 4
- TP1DRow_mm = 5
- TP1DCol_mm = 6
- ZeRO = 7
- DP = 8
+ TP1D = 0
+ ZeRO = 1
+ DP = 2
class ParallelAction(object):
@@ -45,14 +37,14 @@ class TensorSpec(object):
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
# parallel_action_list = [
# ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
- # ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
+ # ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ]
# When the ColoTensor is initialized,
# we first splitting tensor according to ParallelAction of ZeRO,
- # then splitting tensor according to ParallelAction of TP1DRow_Linear.
+ # then splitting tensor according to ParallelAction of TP1D_Linear.
# During Linear computation
# Before Linear Op, we gather the tensors according to ZeRO.
- # We perform Linear Op according to compute pattern of TP1DRow_Linear.
+ # We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
@@ -90,10 +82,17 @@ class TensorSpec(object):
def is_gathered(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.dist_spec.process_group.size() == 1)
- def is_1Dcol(self):
+ def is_1D_col(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
+ def is_1D_row(self):
+ return self.dist_spec.placement == DistPlacementPattern.SHARD \
+ and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+ def has_compute_pattern(self, compute_pattern: ComputePattern):
+ return self.get_action_by_compute_pattern(compute_pattern) is not None
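Taken together, the collapsed enum and the new helpers mean an op no longer asks whether the pattern is TP1DRow or TP1DCol; it asks whether the pattern is TP1D and then derives row versus column from the shard dimension of the dist spec. A standalone mock of that derivation (MockDistSpec and MockTensorSpec are illustrative stand-ins, not the library classes; the placement check and the ParallelAction lookup of the real methods are simplified to a plain list):

from dataclasses import dataclass, field
from enum import Enum
from typing import List, Tuple

class ComputePattern(Enum):
    TP1D = 0
    ZeRO = 1
    DP = 2

@dataclass
class MockDistSpec:
    dims: Tuple[int, ...] = ()      # which dims are sharded; mirrors _DistSpec.dims

@dataclass
class MockTensorSpec:
    dist_spec: MockDistSpec
    compute_patterns: List[ComputePattern] = field(default_factory=list)

    def has_compute_pattern(self, pattern: ComputePattern) -> bool:
        return pattern in self.compute_patterns

    def is_1D_row(self) -> bool:
        return len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

    def is_1D_col(self) -> bool:
        return len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

# The same TP1D pattern, disambiguated purely by the shard dimension.
row = MockTensorSpec(MockDistSpec(dims=(0,)), [ComputePattern.TP1D])
col = MockTensorSpec(MockDistSpec(dims=(-1,)), [ComputePattern.TP1D])
print(row.has_compute_pattern(ComputePattern.TP1D), row.is_1D_row(), row.is_1D_col())   # True True False
print(col.has_compute_pattern(ComputePattern.TP1D), col.is_1D_row(), col.is_1D_col())   # True False True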