diff --git a/colossalai/tensor/_ops/__init__.py b/colossalai/tensor/_ops/__init__.py index 0e2b8169d..e9ce2b1ff 100644 --- a/colossalai/tensor/_ops/__init__.py +++ b/colossalai/tensor/_ops/__init__.py @@ -2,4 +2,5 @@ from .linear import colo_linear from .element_wise import * from .layernorm import colo_layernorm from .loss import colo_cross_entropy -from .embedding import colo_embedding \ No newline at end of file +from .embedding import colo_embedding +from .addmm import colo_addmm diff --git a/colossalai/tensor/_ops/addmm.py b/colossalai/tensor/_ops/addmm.py new file mode 100644 index 000000000..7c725313a --- /dev/null +++ b/colossalai/tensor/_ops/addmm.py @@ -0,0 +1,115 @@ +import torch +from typing import Union +from colossalai.tensor.op_wrapper import colo_op_impl +from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, reduce_grad +from colossalai.nn.layer.utils import divide +from colossalai.core import global_context as gpc +from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern +from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv + + +def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float], + alpha: Union[int, float]) -> ColoTensor: + parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_mm) + # mat1:S[1] x mat2:S[0] = Output:P + # beta * input + alpha * All-Reduce(Output) = res + + # mat1:S[1] + if mat1.is_gathered(): + # Not splited yet. + assert divide(mat1.shape[-1], gpc.tensor_parallel_size) == mat2.size(0), \ + 'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format( + mat1.shape, mat2.shape, mat2.size(0) * gpc.tensor_parallel_size) + input_per_partition = split_forward_gather_backward(mat1.torch_tensor(), parallel_action.parallel_mode, dim=-1) + elif mat1.shard_pattern == ShardPattern.Col: + # Splited by 1Dcol + assert mat1.shape[-1] == mat2.size(0), \ + 'Invalid shapes in 1Drow forward: mat1={}, mat2={}. Expected last dim of input {}.'.format( + mat1.shape, mat2.shape, mat2.size(0)) + input_per_partition = mat1.torch_tensor() + else: + raise NotImplementedError + + # Output:P + partial_output = torch.mm(input_per_partition, mat2.torch_tensor()) + # Reduce(Output) + output = reduce_input(partial_output, parallel_action.parallel_mode) + # input + assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op' + output = beta * input_tensor.torch_tensor() + alpha * output + output = ColoTensor.init_from_torch_tensor(output) + return output + + +def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Union[int, float], + alpha: Union[int, float]) -> ColoTensor: + # mat1:B x mat2:S[1] + input:S[1] = Output:S[1] + # All-Gather(Output) + # mat1:B + parallel_action = mat2.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_mm) + if mat1.is_gathered(): + # Not splited yet. + assert mat1.shape[-1] == mat2.size(0), \ + 'Invalid shapes in 1Dcol forward: mat1={}, mat2={}. Expected last dim of input {}.'.format( + mat1.shape, mat2.shape, mat2.size(0)) + input_parallel = reduce_grad(mat1.torch_tensor(), parallel_action.parallel_mode) + + # input:S[1] + assert input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1 and \ + input_tensor.shard_pattern in [ShardPattern.Col, ShardPattern.Row], \ + 'Invalid bias spec for 1Dcol Linear op' + + output_parallel = torch.addmm(input_tensor.torch_tensor(), + input_parallel, + mat2.torch_tensor(), + beta=beta, + alpha=alpha) + + output = ColoTensor.init_from_torch_tensor(output_parallel) + out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)] + output_spec = TensorSpec(out_parallel_action_list) + output.set_spec(output_spec, shard=False) + output.set_shard_pattern(ShardPattern.Col) + if parallel_action.gather_out: + # All-Gather(Output) + output.gather() + return output + + +@colo_op_impl(torch.addmm) +def colo_addmm(types, args, kwargs, pg): + """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``. + This method computes a linear. + """ + input_tensor, mat1, mat2 = tuple( + map(lambda t: t if isinstance(t, ColoTensor) else ColoTensor.init_from_torch_tensor(t), args[:3])) + beta = kwargs.get('beta', 1) if kwargs else 1 + alpha = kwargs.get('alpha', 1) if kwargs else 1 + + # building the computing graph, inputs -> op + # if GraphGlobalEnv().graph_building: + # cur_op_node = GraphOpNode('linear', [weight, bias]) + # cur_op_node.add_prev_tensor(input_tensor) + + # Add communication logic before and after linear call. + ret_tensor = None + if not mat2.has_spec(): # No Model Parallel Applied + assert not input_tensor.has_spec(), 'Invalid input spec for native addmm op' + ret_tensor = ColoTensor.init_from_torch_tensor( + torch.addbmm(input_tensor.torch_tensor(), mat1.torch_tensor(), mat2.torch_tensor(), beta=beta, alpha=alpha)) + elif mat2.shard_spec.num_action == 1: # Single Model Parallel Applied + compute_patterns = mat2.shard_spec.compute_patterns + if ComputePattern.TP1DRow_mm in compute_patterns: + ret_tensor = colo_addmm_1Drow(input_tensor, mat1, mat2, beta, alpha) + elif ComputePattern.TP1DCol_mm in compute_patterns: + ret_tensor = colo_addmm_1Dcol(input_tensor, mat1, mat2, beta, alpha) + else: + raise NotImplementedError + else: + raise NotImplementedError + + # building the computing graph, op -> output + # if GraphGlobalEnv().graph_building: + # cur_op_node.add_post_tensor(ret_tensor) + + return ret_tensor diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py index f3a542ff6..a0cb1bfe2 100644 --- a/colossalai/tensor/colo_tensor.py +++ b/colossalai/tensor/colo_tensor.py @@ -142,12 +142,15 @@ class ColoTensor(object): # Model Parameters if self._shard_spec.num_action == 1: parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0]) - if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \ - ComputePattern.TP1DCol_Embedding]: + if parallel_action.compute_pattern in [ + ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm + ]: self._shard_1d(parallel_action=parallel_action, dim=-1) - self._shard_pattern = ShardPattern.Col # We bind our ComputePattern on weight, which has to be transposed when linear(). - elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \ - ComputePattern.TP1DRow_Embedding]: + # We bind our ComputePattern on weight, which has to be transposed when linear(). + self._shard_pattern = ShardPattern.Col + elif parallel_action.compute_pattern in [ + ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm + ]: self._shard_1d(parallel_action=parallel_action, dim=0) self._shard_pattern = ShardPattern.Row else: diff --git a/colossalai/tensor/spec.py b/colossalai/tensor/spec.py index 1cba95e96..eb42fdf0e 100644 --- a/colossalai/tensor/spec.py +++ b/colossalai/tensor/spec.py @@ -8,8 +8,10 @@ class ComputePattern(Enum): TP1DCol_Linear = 2 TP1DRow_Embedding = 3 TP1DCol_Embedding = 4 - ZeRO = 5 - DP = 6 + TP1DRow_mm = 5 + TP1DCol_mm = 6 + ZeRO = 7 + DP = 8 class ShardPattern(Enum): diff --git a/tests/test_tensor/test_addmm_tp.py b/tests/test_tensor/test_addmm_tp.py new file mode 100644 index 000000000..1bd99fbd6 --- /dev/null +++ b/tests/test_tensor/test_addmm_tp.py @@ -0,0 +1,81 @@ +import colossalai +import torch +import pytest +import torch.nn as nn +import torch.multiprocessing as mp +from colossalai.utils import ColoInitContext +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction +from colossalai.context import ParallelMode +from colossalai.utils.cuda import get_current_device +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from functools import partial + + +class Conv1D(nn.Module): + """ + 1D-convolutional layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2). + Basically works like a linear layer but the weights are transposed. + Args: + nf (`int`): The number of output features. + nx (`int`): The number of input features. + """ + + def __init__(self, nf, nx): + super().__init__() + self.nf = nf + w = torch.empty(nx, nf) + nn.init.normal_(w, std=0.02) + self.weight = nn.Parameter(w) + self.bias = nn.Parameter(torch.ones(nf)) + + def forward(self, x): + size_out = x.size()[:-1] + (self.nf,) + x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight) + x = x.view(size_out) + return x + + +def init_1d_row(model): + spec = TensorSpec( + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_mm, parallel_mode=ParallelMode.PARALLEL_1D)]) + for n, p in model.colo_named_parameters(): + if 'weight' in n: + p.set_spec(spec) + + +def init_1d_col(model): + spec = TensorSpec( + [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_mm, parallel_mode=ParallelMode.PARALLEL_1D)]) + for n, p in model.colo_named_parameters(): + p.set_spec(spec) + + +def run_with_spec(spec_init_func): + with ColoInitContext(device=get_current_device()): + model = Conv1D(4, 16) + weight = model.weight.torch_tensor().clone() + bias = model.bias.torch_tensor().clone() + spec_init_func(model) + x = torch.rand(2, 16).cuda() + out = model(x) + assert torch.allclose(out.torch_tensor(), torch.addmm(bias, x, weight)) + + +def run_dist(rank, world_size, port): + config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_with_spec(init_1d_row) + run_with_spec(init_1d_col) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [1, 2, 4]) +@rerun_if_address_is_in_use() +def test_addmm_1d(world_size): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_addmm_1d(2)