diff --git a/colossalai/tensor/_ops/linear.py b/colossalai/tensor/_ops/linear.py
index d8bc338a5..d85893969 100644
--- a/colossalai/tensor/_ops/linear.py
+++ b/colossalai/tensor/_ops/linear.py
@@ -19,12 +19,18 @@ def colo_linear(types, args, kwargs, pg):
         bias = None
     else:
         bias = kwargs.get('bias', None)
-
+
     if isinstance(bias, ColoTensor):
         bias = bias.torch_tensor()
 
     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        if weight.shard_spec is None:
+            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
+        elif weight.shard_spec == '1Drow':
+            # TODO(jzy): implement 1Drow TP linear here.
+            raise NotImplementedError
+        else:
+            raise NotImplementedError
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index f40034dc1..8900a42ff 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -1,6 +1,6 @@
 import torch
 from .op_wrapper import _COLOSSAL_OPS
-from typing import Tuple
+from typing import Tuple, Optional
 
 
 class ColoTensor(object):
@@ -21,20 +21,35 @@ class ColoTensor(object):
                  requires_grad=False,
                  pin_memory=False,
                  torch_tensor=torch.empty(0),
+                 shard_spec: Optional[str] = None,
                  ):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
         self._torch_tensor = torch_tensor
+        self._shard_spec = shard_spec
+
+    @property
+    def shard_spec(self) -> Optional[str]:
+        return self._shard_spec
+
+    @property
+    def data(self):
+        return self._torch_tensor.data
+
+    @property
+    def grad(self):
+        return self._torch_tensor.grad
 
     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor):
+    def init_from_torch_tensor(tensor: torch.Tensor, shard_spec: Optional[str] = None) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.pin_memory,
-                            torch_tensor=tensor)
+                            torch_tensor=tensor,
+                            shard_spec=shard_spec)
         return colo_t
 
     def del_torch_tensor(self) -> None:
@@ -67,7 +82,5 @@ class ColoTensor(object):
         if kwargs is None:
             kwargs = {}
 
-        kwargs = {
-            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k,v in kwargs.items()
-        }
+        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
         return func(*args, **kwargs)
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
new file mode 100644
index 000000000..4adb848b1
--- /dev/null
+++ b/tests/test_tensor/test_linear_tp.py
@@ -0,0 +1,74 @@
+from joblib import Parallel
+from numpy import allclose, require
+import torch
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ColoTensor
+from copy import deepcopy
+
+from functools import partial
+
+import colossalai
+import pytest
+import torch
+import torch.multiprocessing as mp
+from colossalai.logging import get_dist_logger
+from colossalai.testing import parameterize, rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+
+
+def run_linear_tp1d_row_test():
+    in_dim = 4
+    out_dim = 5
+
+    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+    fc_ref = deepcopy(fc)
+
+    input_ref = torch.randn(1, in_dim)
+    input_tensor = input_ref.clone()
+
+    # sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight, "1Drow")
+
+    # shard the weight at the beginning
+    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+    sharded_weight = ColoTensor(in_dim // world_size, out_dim, shard_spec="1Drow")
+    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)
+
+    # replace the torch nn.Parameters with ColoTensors
+    delattr(fc, 'weight')
+    setattr(fc, 'weight', sharded_weight)
+    delattr(fc, 'bias')
+    setattr(fc, 'bias', sharded_bias)
+
+    fc.weight.requires_grad = True
+    fc.bias.requires_grad = True
+
+    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
+    out = fc(input_tensor)
+    loss = out.sum()
+    loss.backward()
+
+    out_ref = fc_ref(input_ref)
+    loss_ref = out_ref.sum()
+    loss_ref.backward()
+
+    assert loss_ref == loss
+    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)
+
+
+def run_dist(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_linear_tp1d_row_test()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("world_size", [4])
+@rerun_if_address_is_in_use()
+def test_linear_1d(world_size):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_linear_1d(4)
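Not part of the patch above: a minimal sketch of what the `1Drow` branch in `colo_linear` could eventually look like, assuming the full input is replicated on every rank and each rank's weight shard covers a contiguous slice of the input features. It uses plain `torch.distributed` collectives rather than ColossalAI's process-group utilities, and the helper name `row_parallel_linear` is illustrative only.

```python
import torch
import torch.distributed as dist


def row_parallel_linear(input_tensor: torch.Tensor,
                        weight_shard: torch.Tensor,
                        bias: torch.Tensor = None,
                        group: dist.ProcessGroup = None) -> torch.Tensor:
    """Sketch of a 1Drow (row-parallel) linear forward.

    Assumes `weight_shard` has shape (out_dim, in_dim // world_size) and that
    the full input is replicated on every rank, so each rank computes a
    partial output that must be summed across the tensor-parallel group.
    """
    rank = dist.get_rank(group)
    world_size = dist.get_world_size(group)

    # Each rank takes the slice of input features that matches its weight shard.
    in_per_rank = input_tensor.shape[-1] // world_size
    local_input = input_tensor[..., rank * in_per_rank:(rank + 1) * in_per_rank]

    # Partial matmul on the local shard; bias is deliberately left out here.
    partial_out = torch.nn.functional.linear(local_input, weight_shard)

    # Sum the partial results across the tensor-parallel group.
    # Note: dist.all_reduce is not autograd-aware; a real implementation needs
    # a differentiable all-reduce so gradients reach the weight shards.
    dist.all_reduce(partial_out, op=dist.ReduceOp.SUM, group=group)

    # Add the (replicated) bias exactly once, after the reduction.
    if bias is not None:
        partial_out = partial_out + bias
    return partial_out
```

The bias is added after the all-reduce so it is applied only once, and the gradient comparison in `test_linear_tp.py` against the unsharded `fc_ref` is exactly the kind of check such an implementation would have to pass.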