diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py
index 39b279c01..0e2b8169d 100644
--- a/colossalai/tensor/_ops/__init__.py
+++ b/colossalai/tensor/_ops/__init__.py
@@ -2,3 +2,4 @@ from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
+from .embedding import colo_embedding
\ No newline at end of file
diff --git a/colossalai/tensor/_ops/embedding.py b/colossalai/tensor/_ops/embedding.py
new file mode 100644
index 000000000..84c95492f
--- /dev/null
+++ b/colossalai/tensor/_ops/embedding.py
@@ -0,0 +1,56 @@
+import torch
+from colossalai.tensor.op_wrapper import colo_op_impl
+from colossalai.context import ParallelMode
+from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
+    gather_forward_split_backward, reduce_grad
+from colossalai.nn.layer.utils import divide
+from colossalai.core import global_context as gpc
+from packaging import version
+from colossalai.tensor import ComputePattern, TensorSpec, ParallelAction, ColoTensor, ShardPattern
+
+def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
+    # embedding_1Dcol splits the weight (lookup table) along the embedding dim;
+    # the column-sharded lookup output is gathered before returning.
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
+    if not input_tensor.is_gathered():
+        input_tensor.gather()
+
+    output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(),
+                                                    *args, **kwargs)
+    output = ColoTensor.init_from_torch_tensor(output_parallel)
+    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
+    output_spec = TensorSpec(out_parallel_action_list)
+    output.set_spec(output_spec, shard=False)
+    output.set_shard_pattern(ShardPattern.Col)
+    output.gather()
+    return output
+
+@colo_op_impl(torch.nn.functional.embedding)
+def colo_embedding(types, args, kwargs, pg):
+    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
+    This method looks up an embedding table.
+    """
+    input_tensor = args[0]
+    weight = args[1]
+    args = args[2:]
+
+    if not isinstance(input_tensor, ColoTensor):
+        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
+
+    if not isinstance(weight, ColoTensor):
+        weight = ColoTensor.init_from_torch_tensor(weight)
+
+    # Handle different parallel actions.
+    if not weight.has_spec():  # No Model Parallel Applied
+        input_tensor = input_tensor.torch_tensor()
+        weight = weight.torch_tensor()
+        output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
+        return ColoTensor.init_from_torch_tensor(output)
+    elif weight.shard_spec.num_action == 1:  # Single Model Parallel Applied
+        compute_patterns = weight.shard_spec.compute_patterns
+        if ComputePattern.TP1DCol_Embedding in compute_patterns:
+            return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
+        else:
+            raise NotImplementedError
+    else:
+        raise NotImplementedError
diff --git a/colossalai/tensor/_ops/layernorm.py b/colossalai/tensor/_ops/layernorm.py
index b59d5a00b..28ac286fa 100644
--- a/colossalai/tensor/_ops/layernorm.py
+++ b/colossalai/tensor/_ops/layernorm.py
@@ -27,7 +27,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
         eps = kwargs['eps']
 
     if isinstance(input_tensor, ColoTensor):
-        if input_tensor.is_activation() and not input_tensor.is_gathered():
+        if not input_tensor.is_gathered():
             input_tensor.gather()
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index 8dca27d8d..5e3f4934b 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -9,8 +9,8 @@ from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
 
 
-def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias:ColoTensor) -> ColoTensor:
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -47,7 +47,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Linear)
     if input_tensor.is_gathered():
         # Not splited yet.
         assert input_tensor.shape[-1] == weight.size(-1), \
@@ -108,9 +108,9 @@ def colo_linear(types, args, kwargs, pg):
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:  # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+        if ComputePattern.TP1DRow_Linear in compute_patterns:
             return colo_linear_1Drow(input_tensor, weight, bias)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif ComputePattern.TP1DCol_Linear in compute_patterns:
             return colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 37152f956..06a751a77 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -142,14 +142,19 @@ class ColoTensor(object):
         if self._shard_pattern is not ShardPattern.NA:  # reshard
             self.gather()
         # Model Parameters
-        if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-            self._shard_1d(parallel_action=parallel_action, dim=-1)
-            self._shard_pattern = ShardPattern.Col  # We bind our ComputePattern on weight, which has to be transposed when linear().
-        elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
-            self._shard_1d(parallel_action=parallel_action, dim=0)
-            self._shard_pattern = ShardPattern.Row
+        if self._shard_spec.num_action == 1:
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                self._shard_spec.compute_patterns[0])
+            if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
+                    ComputePattern.TP1DCol_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+                self._shard_pattern = ShardPattern.Col  # We bind our ComputePattern on weight, which has to be transposed when linear().
+            elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
+                    ComputePattern.TP1DRow_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+                self._shard_pattern = ShardPattern.Row
+            else:
+                raise NotImplementedError
 
     def gather(self):
         assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py
index 96dc414b0..1cba95e96 100644
--- a/colossalai/tensor/spec.py
+++ b/colossalai/tensor/spec.py
@@ -4,10 +4,12 @@ from colossalai.context.parallel_mode import ParallelMode
 
 
 class ComputePattern(Enum):
-    TP1DRow = 1
-    TP1DCol = 2
-    ZeRO = 3
-    DP = 4
+    TP1DRow_Linear = 1
+    TP1DCol_Linear = 2
+    TP1DRow_Embedding = 3
+    TP1DCol_Embedding = 4
+    ZeRO = 5
+    DP = 6
 
 
 class ShardPattern(Enum):
@@ -43,14 +45,14 @@ class TensorSpec(object):
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
     # parallel_action_list = [
     #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1DRow, gpc.get_group(ParallelMode.PARALLEL_1D))
+    #     ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
     # ]
     # When the ColoTensor is initialized,
     # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1DRow.
+    # then splitting tensor according to ParallelAction of TP1DRow_Linear.
     # During Linear computation
     # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1DRow.
+    # We perform Linear Op according to compute pattern of TP1DRow_Linear.
     # After Linear Op, we split the tensors according to ZeRO.
 
     def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py
new file mode 100644
index 000000000..3b145ca1a
--- /dev/null
+++ b/tests/test_tensor/test_embedding_tp.py
@@ -0,0 +1,82 @@
+import torch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ColoTensor
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+
+from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk
+
+def run_embedding_tp1d_col_test():
+    device = get_current_device()
+    dtype = torch.float32
+    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    num_embeddings = 12
+    embedding_dim = 32
+
+    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+
+    layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
+    layer = torch.nn.Embedding(num_embeddings, embedding_dim)
+
+    A_master = torch.tensor((0,3,6,9), device=device)
+    A = broadcast_tensor_chunk(A_master, chunk_size=1)
+
+    W_shape = (num_embeddings, embedding_dim)
+    W_master = torch.randn(W_shape, dtype=dtype, device=device)
+    W = broadcast_tensor_chunk(W_master, chunk_size=1)
+    W.requires_grad = True
+
+    # replace the torch nn.Parameters with ColoTensor
+    sharded_weight = ColoTensor.init_from_torch_tensor(W)
+    parallel_action_list = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec = TensorSpec(parallel_action_list)
+    sharded_weight.set_spec(spec)  # reshard
+    replace_parameter_add_grad(layer, sharded_weight)
+    out = layer(A)
+
+    replace_parameter_add_grad(layer_master, W_master)
+    C_master = layer_master(A_master)
+    C = C_master.clone()
+
+    check_equal(out, C)
+
+    grad_shape = C_master.shape
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
+    out.backward(grad)
+
+    grad_master = grad_master.clone()
+    C_master.backward(grad_master)
+
+    W_grad = W_master.grad
+    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
+    check_equal(W_grad, layer.weight.grad)
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_embedding_tp1d_col_test()
+
+@pytest.mark.dist
+@parameterize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_embedding_1d(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_embedding_1d()
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
index 49de72012..a57943f5d 100644
--- a/tests/test_tensor/test_linear_tp.py
+++ b/tests/test_tensor/test_linear_tp.py
@@ -47,7 +47,7 @@ def run_linear_tp1d_col_test():
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     sharded_bias = ColoTensor.init_from_torch_tensor(B)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec)  # reshard
@@ -110,7 +110,7 @@ def run_linear_tp1d_row_test():
     # replace the torch nn.Parameters with ColoTensor
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec=spec)  # reshard
@@ -145,7 +145,7 @@ def run_linear_tp1d_row_test():
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    #run_linear_tp1d_row_test()
+    run_linear_tp1d_row_test()
     run_linear_tp1d_col_test()
 
 @pytest.mark.dist
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index ba66f1715..7610c5d8d 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -38,12 +38,12 @@ def run_1d_col_tp():
     model = model_builder(checkpoint=True)
 
     parallel_action_list_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_row = TensorSpec(parallel_action_list_row)
 
     parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_col = TensorSpec(parallel_action_list_col)
 
@@ -168,7 +168,7 @@ def run_1d_row_tp():
     model = model_builder(checkpoint=True)
 
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
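
Below is a minimal usage sketch (not part of the patch) of the new `TP1DCol_Embedding` compute pattern, mirroring `tests/test_tensor/test_embedding_tp.py`. It assumes `colossalai.launch` has already initialized a 1D tensor-parallel group; the seed and the example tensor values are illustrative assumptions.

```python
# Usage sketch for ComputePattern.TP1DCol_Embedding (assumes colossalai.launch()
# was called with parallel=dict(tensor=dict(mode="1d", size=world_size))).
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, ComputePattern, ParallelAction
from colossalai.utils.cuda import get_current_device

torch.manual_seed(0)  # every rank must hold the same full weight before resharding
W = torch.randn(12, 32, device=get_current_device())

# Wrap the lookup table and attach a column-sharded embedding spec;
# set_spec() reshards the weight along the embedding dim (ShardPattern.Col).
sharded_weight = ColoTensor.init_from_torch_tensor(W)
spec = TensorSpec([
    ParallelAction(priority=1,
                   compute_pattern=ComputePattern.TP1DCol_Embedding,
                   parallel_mode=ParallelMode.PARALLEL_1D)
])
sharded_weight.set_spec(spec)  # reshard

# The ColoTensor weight routes F.embedding through colo_embedding, which picks
# colo_embedding_1Dcol and gathers the column-sharded output before returning.
ids = torch.tensor((0, 3, 6, 9), device=get_current_device())
out = torch.nn.functional.embedding(ids, sharded_weight)
```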