[Tensor] TP Linear 1D row (#843)

Ziyue Jiang
2022-04-24 13:43:12 +08:00
committed by GitHub
parent cf6d1c9284
commit 05023ecfee
5 changed files with 154 additions and 4 deletions

View File

@@ -1,9 +1,10 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.context import ParallelMode
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
from packaging import version


@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@@ -19,12 +20,31 @@ def colo_linear(types, args, kwargs, pg):
            bias = None
    else:
        bias = kwargs.get('bias', None)
        if isinstance(bias, ColoTensor):
            bias = bias.torch_tensor()

    # Add communication logic before and after linear call.
    if isinstance(weight, ColoTensor):
        if weight.shard_spec == None:
            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
        elif weight.shard_spec == '1Drow':
            """
            Input:S[1] x Weight:S[0] = Output:P
            All-Reduce(Output) + bias = res
            """
            # Input:S[1]
            input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
            # Output:P
            partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
            # Reduce(Output)
            output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
            # Bias
            if bias is not None:
                output = output + bias
            return output
        else:
            raise NotImplementedError
    else:
        return torch.nn.functional.linear(input_tensor, weight, bias)
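
As a numerical sanity check on the 1Drow branch above, here is a minimal single-process sketch (illustrative, not part of this commit): it emulates the Input:S[1] and Weight:S[0] shards with torch.chunk and stands in for the all-reduce with a plain Python sum.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
world_size = 2
x = torch.randn(4, 8)    # full input, in_features = 8
w = torch.randn(6, 8)    # full weight, shape (out_features, in_features)
b = torch.randn(6)       # full bias, replicated on every rank

# Input:S[1] -- split the input along its last (feature) dimension.
x_shards = x.chunk(world_size, dim=-1)
# Weight:S[0] -- shard W along in_features, the dimension contracted with the
# input (rows of W in the x @ W^T view, i.e. dim 1 of the torch weight).
w_shards = w.chunk(world_size, dim=-1)

# Output:P -- every "rank" computes a full-shape but partial output.
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

# All-Reduce(Output) + bias, with the collective replaced by a local sum;
# the bias is added once, after the reduction, exactly as in colo_linear.
out = sum(partials) + b

assert torch.allclose(out, F.linear(x, w, b), atol=1e-5)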

View File

@@ -4,7 +4,6 @@ from typing import Tuple
import numpy
from .op_wrapper import _COLOSSAL_OPS


class ColoTensor(object):
    """ Data Structure for Tensor in Colossal-AI
    1. It contains a torch.Tensor as an attribute.
@@ -24,6 +23,7 @@ class ColoTensor(object):
                 pin_memory=False,
                 device=None,
                 torch_tensor=torch.empty(0),
                 shard_spec: str = None,
                 ):
        self._size = size
        self._dtype = dtype
@@ -31,11 +31,29 @@ class ColoTensor(object):
        self._pin_memory = pin_memory
        self._device = device
        self._torch_tensor = torch_tensor
        self._shard_spec = shard_spec

    @property
    def shard_spec(self) -> Optional[str]:
        return self._shard_spec

    @property
    def data(self):
        return self._torch_tensor.data

    @property
    def grad(self):
        return self._torch_tensor.grad

    @property
    def size(self):
        return self._size

    def numel(self):
        return product(self._size)

    @staticmethod
    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
        colo_t = ColoTensor(*tensor.size(),
                            dtype=tensor.dtype,
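
The second hunk is truncated here, but a hypothetical end-to-end usage of the two changes together would look roughly like this (the shapes and the constructor call are illustrative; actually running the intercepted call requires colossalai to be launched with a 1D tensor-parallel context):

import torch
from colossalai.tensor.colo_tensor import ColoTensor

# Each rank holds only its slice of the weight's in_features dimension;
# shard_spec='1Drow' is the new constructor kwarg that colo_linear
# dispatches on.
local_weight = torch.randn(6, 4)    # (out_features, in_features // world_size)
weight = ColoTensor(6, 4,
                    dtype=local_weight.dtype,
                    torch_tensor=local_weight,
                    shard_spec='1Drow')

x = torch.randn(2, 8)    # full-width input; colo_linear splits it along dim -1
# With PARALLEL_1D initialized, the call below is intercepted via
# ``__torch_function__`` and takes the 1Drow path:
# out = torch.nn.functional.linear(x, weight)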