diff --git a/colossalai/tensor/_ops/addmm.py b/colossalai/tensor/_ops/addmm.py index c45b85e3a..30f01d2cd 100644 --- a/colossalai/tensor/_ops/addmm.py +++ b/colossalai/tensor/_ops/addmm.py @@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float], alpha: Union[int, float]) -> ColoTensor: - parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D) # mat1:S[1] x mat2:S[0] = Output:P # beta * input + alpha * All-Reduce(Output) = res @@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float], alpha: Union[int, float]) -> ColoTensor: # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] - parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) + parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D) mat1.to_dist_spec(dist_spec.replicate(mat2.spec.get_process_group())) mat1_torch_tensor = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode) @@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg): # Add communication logic before and after linear call. ret_tensor = None if not mat2.has_spec(): # No Model Parallel Applied - assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op' + assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op' + assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op' ret_tensor = ColoTensor.init_from_torch_tensor( - torch.addbmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha)) - elif mat2.spec.num_action == 1: # Single Model Parallel Applied + torch.addmm(input_tensor.torch_tensor(), mat1, mat2.torch_tensor(), beta=beta, alpha=alpha)) + elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied spec = TensorSpec(dist_spec.replicate(mat2.spec.get_process_group())) mat1 = args[1] if isinstance(args[1], ColoTensor) else ColoTensor.init_from_torch_tensor(args[1], spec=spec) - compute_patterns = mat2.spec.compute_patterns - if ComputePattern.TP1DRow in compute_patterns: + if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered(): ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha) - elif ComputePattern.TP1DCol in compute_patterns: + elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()): ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha) else: raise NotImplementedError diff --git a/colossalai/tensor/_ops/embedding.py b/colossalai/tensor/_ops/embedding.py index 40404278c..308794f98 100644 --- a/colossalai/tensor/_ops/embedding.py +++ b/colossalai/tensor/_ops/embedding.py @@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor: # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P) # Gather splitted lookup table - parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) + parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D) input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group())) output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(), *args, **kwargs) @@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim) # Find index in this shard and mask those not here # Reduce all - parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D) input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group())) tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode) @@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg): weight = ColoTensor.init_from_torch_tensor(weight) # Handle differen parallel actions. + if not weight.has_spec(): # No Model Parallel Applied + assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op' input_tensor = input_tensor.torch_tensor() weight = weight.torch_tensor() output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs) return ColoTensor.init_from_torch_tensor(output) - elif weight.spec.num_action == 1: # Single Model Parallel Applied - compute_patterns = weight.spec.compute_patterns - if ComputePattern.TP1DRow in compute_patterns: + elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.spec.is_1D_row(): return colo_embedding_1Drow(input_tensor, weight, args, kwargs) - elif ComputePattern.TP1DCol in compute_patterns: + elif weight.spec.is_1D_col(): return colo_embedding_1Dcol(input_tensor, weight, args, kwargs) else: raise NotImplementedError diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py index 8bc6c3ee7..0b1128e87 100644 --- a/colossalai/tensor/_ops/linear.py +++ b/colossalai/tensor/_ops/linear.py @@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor: - parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DRow) + parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D) # Input:S[1] x Weight:S[0] = Output:P # All-Reduce(Output) + bias = res # Input:S[1] @@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1] # All-Gather(Output) # Input:B - parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1DCol) + parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D) input_tensor.to_dist_spec(dist_spec.replicate(weight.spec.get_process_group())) input_parallel = reduce_grad(input_tensor.torch_tensor(), parallel_action.parallel_mode) - - output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias.torch_tensor()) + if bias is not None: + bias = bias.torch_tensor() + output_parallel = torch.nn.functional.linear(input_parallel, weight.torch_tensor(), bias) output = ColoTensor.init_from_torch_tensor( output_parallel, @@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg): # Add communication logic before and after linear call. ret_tensor = None if not weight.has_spec(): # No Model Parallel Applied - assert not bias.has_spec(), 'Invalid bias spec for native Linear op' + assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op' + assert bias.spec.is_gathered(), 'Invalid bias spec for native Linear op' input_tensor = input_tensor.torch_tensor() weight = weight.torch_tensor() - bias = bias.torch_tensor() + if bias is not None: + bias = bias.torch_tensor() ret_tensor = ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias)) - elif weight.spec.num_action == 1: # Single Model Parallel Applied - compute_patterns = weight.spec.compute_patterns - if ComputePattern.TP1DRow in compute_patterns: + elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()): ret_tensor = colo_linear_1Drow(input_tensor, weight, bias) - elif ComputePattern.TP1DCol in compute_patterns: + elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()): ret_tensor = colo_linear_1Dcol(input_tensor, weight, bias) else: raise NotImplementedError diff --git a/colossalai/tensor/_ops/loss.py b/colossalai/tensor/_ops/loss.py index 6243301fd..8e343ee21 100644 --- a/colossalai/tensor/_ops/loss.py +++ b/colossalai/tensor/_ops/loss.py @@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl from colossalai.tensor import ColoTensor from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D + @colo_op_impl(torch.nn.functional.cross_entropy) def colo_cross_entropy(types, args=(), kwargs=None, pg=None): arg_num = len(args) @@ -27,13 +28,13 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None): if isinstance(target, ColoTensor): target = target.torch_tensor() - if input_tensor.spec.is_gathered(): # Input is gathered - return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy( - input_tensor.torch_tensor(), target, weight)) + if input_tensor.spec.is_gathered(): # Input is gathered + return ColoTensor.init_from_torch_tensor( + torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight)) elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied - if input_tensor.spec.is_1Dcol(): - return ColoTensor.init_from_torch_tensor( - VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target)) + if input_tensor.spec.is_1D_col(): + return ColoTensor.init_from_torch_tensor(VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), + target)) else: raise NotImplementedError else: diff --git a/colossalai/tensor/dist_spec.py b/colossalai/tensor/dist_spec.py index bad02922e..905bf975e 100644 --- a/colossalai/tensor/dist_spec.py +++ b/colossalai/tensor/dist_spec.py @@ -1,6 +1,7 @@ from enum import Enum from torch.distributed import ProcessGroup from typing import Optional, List +from numpy import prod __all__ = ['replicate', 'shard'] @@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int assert process_group is not None assert isinstance(dims, list) and isinstance(num_partitions, list) assert len(dims) == len(num_partitions) + assert prod(num_partitions) == process_group.size() return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions)) diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py index d1b762f0b..97b2b7cda 100644 --- a/colossalai/tensor/spec.py +++ b/colossalai/tensor/spec.py @@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern class ComputePattern(Enum): - # TODO (ver217): remove TP1DRow_ - TP1DRow = 0 - TP1DCol = 9 - TP1DRow_Linear = 1 - TP1DCol_Linear = 2 - TP1DRow_Embedding = 3 - TP1DCol_Embedding = 4 - TP1DRow_mm = 5 - TP1DCol_mm = 6 - ZeRO = 7 - DP = 8 + TP1D = 0 + ZeRO = 1 + DP = 2 class ParallelAction(object): @@ -45,14 +37,14 @@ class TensorSpec(object): # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2. # parallel_action_list = [ # ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)), - # ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D)) + # ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D)) # ] # When the ColoTensor is initialized, # we first splitting tensor according to ParallelAction of ZeRO, - # then splitting tensor according to ParallelAction of TP1DRow_Linear. + # then splitting tensor according to ParallelAction of TP1D_Linear. # During Linear computation # Before Linear Op, we gather the tensors according to ZeRO. - # We perform Linear Op according to compute pattern of TP1DRow_Linear. + # We perform Linear Op according to compute pattern of TP1D_Linear. # After Linear Op, we split the tensors according to ZeRO. def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []): @@ -90,10 +82,17 @@ class TensorSpec(object): def is_gathered(self): return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ - or (len(self.dist_spec.num_partitions) == 1 + or (len(self.dist_spec.num_partitions) == 1 and self.dist_spec.num_partitions[0] == 1) \ or (self.dist_spec.process_group.size() == 1) - def is_1Dcol(self): + def is_1D_col(self): return self.dist_spec.placement == DistPlacementPattern.SHARD \ - and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 \ No newline at end of file + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 + + def is_1D_row(self): + return self.dist_spec.placement == DistPlacementPattern.SHARD \ + and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0 + + def has_compute_pattern(self, compute_pattern: ComputePattern): + return self.get_action_by_compute_pattern(compute_pattern) is not None diff --git a/tests/test_tensor/test_addmm_tp.py b/tests/test_tensor/test_addmm_tp.py index 739566768..67ae49ea9 100644 --- a/tests/test_tensor/test_addmm_tp.py +++ b/tests/test_tensor/test_addmm_tp.py @@ -40,7 +40,7 @@ class Conv1D(nn.Module): def init_1d_row(weight, bias): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) @@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias): def init_1d_col(weight, bias): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) bias.set_spec(spec) diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py index a80b46148..946aa76b2 100644 --- a/tests/test_tensor/test_embedding_tp.py +++ b/tests/test_tensor/test_embedding_tp.py @@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s def init_1d_row(weight): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) @@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight): def init_1d_col(weight): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index 2d01adce7..326fe045a 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s def init_1d_row(weight, bias): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) @@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias): def init_1d_col(weight, bias): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) bias.set_spec(spec) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index 9cebf72db..e2bcf348e 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -86,35 +86,43 @@ def set_seed(seed): torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True + def init_1d_row_linear(weight): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) + def init_1d_col_linear(weight, gather_out=True): spec = TensorSpec( - dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \ - gather_out=gather_out)]) + dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [ + ParallelAction(priority=1, + compute_pattern=ComputePattern.TP1D, + parallel_mode=ParallelMode.PARALLEL_1D, + gather_out=gather_out) + ]) with DistSpecManager.no_grad(): weight.set_spec(spec) + def init_1d_row_embedding(weight): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) + def init_1d_col_embedding(weight): spec = TensorSpec( dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), - [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)]) + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)]) with DistSpecManager.no_grad(): weight.set_spec(spec) + def run_1d_hybrid_tp(model_name): # A simple net with two stacked nn.Linear get_components_func = non_distributed_component_funcs.get_callable(model_name) @@ -124,7 +132,7 @@ def run_1d_hybrid_tp(model_name): set_seed(1) with ColoInitContext(device=get_current_device()): model = model_builder(checkpoint=True) - + if rank == 0: model_torch = model_builder(checkpoint=True) model_torch = model_torch.cuda() @@ -173,7 +181,7 @@ def run_1d_hybrid_tp(model_name): if rank == 0: model_torch.eval() colo_optimizer_torch.zero_grad() - + data = data.to(get_current_device()) label = label.to(get_current_device()) @@ -217,11 +225,11 @@ def run_1d_hybrid_tp(model_name): assert torch.allclose(p1, p2) else: # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec. - if p1.size(-1) < p2.size(-1): # col + if p1.size(-1) < p2.size(-1): # col world_size = p2.size(-1) // p1.size(-1) split_p2 = torch.chunk(p2, world_size, dim=-1)[0] - - elif p1.size(0) < p2.size(0): # row + + elif p1.size(0) < p2.size(0): # row world_size = p2.size(0) // p1.size(0) split_p2 = torch.chunk(p2, world_size, dim=0)[0] @@ -376,7 +384,7 @@ def _run_pretrain_load(): if isinstance(param, ColoParameter): c1 += 1 else: - c2 +=1 + c2 += 1 dict_col[name] = param assert c_ref == c1 assert c2 == 0 @@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port): for name in ['bert', 'simple_net']: run_1d_hybrid_tp(name) + @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 4]) # @parameterize('world_size', [1, 4])