diff --git a/colossalai/tensor/_ops/addmm.py b/colossalai/tensor/_ops/addmm.py
index 8b9d04c8e..bcfdd72ae 100644
--- a/colossalai/tensor/_ops/addmm.py
+++ b/colossalai/tensor/_ops/addmm.py
@@ -3,12 +3,12 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 from colossalai.tensor import distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor


 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
@@ -18,7 +18,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
     assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = mat2.spec.parallel_action
     mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
-    mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
+    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+                             ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     if parallel_action.gather_out:
         # All-Gather(Output)
diff --git a/colossalai/tensor/_ops/embedding.py b/colossalai/tensor/_ops/embedding.py
index eae6a1ef1..25b33e95a 100644
--- a/colossalai/tensor/_ops/embedding.py
+++ b/colossalai/tensor/_ops/embedding.py
@@ -1,10 +1,10 @@
-import torch
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -17,7 +17,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

     output_parallel = F.embedding(input_tensor,
@@ -29,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   sparse=sparse)
     output_spec = TensorSpec(
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+        ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     return output
@@ -45,10 +44,9 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

-    tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -72,7 +70,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     output = ColoTensor.from_torch_tensor(output,
                                           spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
     return output
diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index 1bd6441d8..21c5ff280 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -1,16 +1,14 @@
-import torch
 import torch.nn.functional as F
-import torch.distributed as dist
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
 from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -20,7 +18,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Output:P
     partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
         assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
@@ -34,15 +32,16 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = weight.spec.parallel_action
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
+    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
     output_parallel = F.linear(input_parallel, weight, bias)
-    output = ColoTensor.from_torch_tensor(
-        output_parallel,
-        spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
+    output = ColoTensor.from_torch_tensor(output_parallel,
+                                          spec=TensorSpec(
+                                              distspec.shard(weight.spec.get_process_group(), [-1],
+                                                             [weight.spec.get_process_group_size()]),
+                                              ParallelAction(ComputePattern.TP1D)))
     if parallel_action.gather_out:
         # All-Gather(Output)
         output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
diff --git a/colossalai/tensor/_ops/loss.py b/colossalai/tensor/_ops/loss.py
index 1a41e36a9..cf4468c43 100644
--- a/colossalai/tensor/_ops/loss.py
+++ b/colossalai/tensor/_ops/loss.py
@@ -28,7 +28,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
+    elif input_tensor.has_spec():    # Single Model Parallel Applied
         if input_tensor.spec.is_1D_col():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index e9f144d9e..8c09f088c 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -33,3 +33,6 @@ class ColoParameter(ColoTensor):
         tensor = tensor.as_subclass(ColoParameter)
         tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
         return tensor
+
+    def __repr__(self):
+        return f'ColoParameter: {torch.Tensor.__repr__(self)}'
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 6d78e2cbd..08685e74b 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -45,7 +45,7 @@ class ColoTensor(torch.Tensor):
         self._spec = spec

     def has_spec(self) -> bool:
-        return self._spec.num_action > 0
+        return self._spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index c75eef3be..40a4a2c51 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -1,26 +1,21 @@
 import torch.distributed as dist
 from enum import Enum
-from typing import List
-from colossalai.context.parallel_mode import ParallelMode
+from typing import List, Optional
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):
     TP1D = 0
-    ZeRO = 1
-    DP = 2
+    TP2D = 1
+    TP2P5D = 2
+    TP3D = 3


 class ParallelAction(object):

-    def __init__(self,
-                 priority=0,
-                 compute_pattern=ComputePattern.DP,
-                 parallel_mode=ParallelMode.DATA,
-                 gather_out=True) -> None:
-        self.priority = priority
+    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
+        assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
-        self.parallel_mode = parallel_mode
         self.gather_out = gather_out
@@ -48,32 +43,9 @@ class TensorSpec(object):
     # We perform Linear Op according to compute pattern of TP1D_Linear.
     # After Linear Op, we split the tensors according to ZeRO.
-    def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
-        self._parallel_action_list = parallel_action_list
+    def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
+        self.parallel_action = parallel_action
         self.dist_spec = dist_spec
-        self.sort()
-
-    @property
-    def parallel_action_list(self):
-        return self._parallel_action_list
-
-    @property
-    def num_action(self):
-        return len(self._parallel_action_list)
-
-    @property
-    def compute_patterns(self):
-        return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
-
-    def sort(self):
-        if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
-
-    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
-        for parallel_action in self._parallel_action_list:
-            if parallel_action.compute_pattern == compute_pattern:
-                return parallel_action
-        return None

     def get_process_group(self):
         return self.dist_spec.process_group
@@ -99,4 +71,4 @@ class TensorSpec(object):
               and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def has_compute_pattern(self, compute_pattern: ComputePattern):
-        return self.get_action_by_compute_pattern(compute_pattern) is not None
+        return self.parallel_action.compute_pattern == compute_pattern
diff --git a/tests/test_tensor/test_addmm_tp.py b/tests/test_tensor/test_addmm_tp.py
index b02f4baad..8b68d5cd6 100644
--- a/tests/test_tensor/test_addmm_tp.py
+++ b/tests/test_tensor/test_addmm_tp.py
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py
index 71d0c52bc..8954367c9 100644
--- a/tests/test_tensor/test_embedding_tp.py
+++ b/tests/test_tensor/test_embedding_tp.py
@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -26,7 +26,7 @@ def init_1d_row(weight):
 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py
index 9e1671280..781e36c25 100644
--- a/tests/test_tensor/test_gpt.py
+++ b/tests/test_tensor/test_gpt.py
@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 def init_1d_row_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
 def init_1d_col_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
index ac9a8ece0..f673687ea 100644
--- a/tests/test_tensor/test_linear_tp.py
+++ b/tests/test_tensor/test_linear_tp.py
@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 6a71242df..c9e3da884 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -20,19 +20,15 @@ from _utils import set_seed
 def init_1d_row_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_col_linear(weight, gather_out=True):
     spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1D,
-                           parallel_mode=ParallelMode.PARALLEL_1D,
-                           gather_out=gather_out)
-        ])
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
 def init_1d_row_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
 def init_1d_col_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
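
For reference, a minimal usage sketch of the new TensorSpec/ParallelAction API, mirroring the updated test helpers above. It assumes the 1D tensor-parallel context (ParallelMode.PARALLEL_1D) has already been initialized by colossalai; it is an illustration of the changed constructor, not additional code from this diff.

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.tensor import ComputePattern, ParallelAction, TensorSpec, distspec

    # A TensorSpec now carries a single optional ParallelAction
    # (no priority, no parallel_mode) instead of a priority-sorted list.
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                       [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ParallelAction(ComputePattern.TP1D))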