mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 20:54:35 +00:00
[Tensor] TP Linear 1D row (#843)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import torch
|
||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||
from colossalai.tensor.colo_tensor import ColoTensor
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input
|
||||
from packaging import version
|
||||
|
||||
|
||||
@colo_op_impl(torch.nn.functional.linear)
|
||||
def colo_linear(types, args, kwargs, pg):
|
||||
"""Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
|
||||
@@ -19,12 +20,31 @@ def colo_linear(types, args, kwargs, pg):
|
||||
bias = None
|
||||
else:
|
||||
bias = kwargs.get('bias', None)
|
||||
|
||||
|
||||
if isinstance(bias, ColoTensor):
|
||||
bias = bias.torch_tensor()
|
||||
|
||||
# Add communication logic before and after linear call.
|
||||
if isinstance(weight, ColoTensor):
|
||||
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
|
||||
if weight.shard_spec == None:
|
||||
return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
|
||||
elif weight.shard_spec == '1Drow':
|
||||
"""
|
||||
Input:S[1] x Weight:S[0] = Output:P
|
||||
All-Reduce(Output) + bias = res
|
||||
"""
|
||||
# Input:S[1]
|
||||
input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
# Output:P
|
||||
partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||
|
@@ -4,7 +4,6 @@ from typing import Tuple
|
||||
import numpy
|
||||
from .op_wrapper import _COLOSSAL_OPS
|
||||
|
||||
|
||||
class ColoTensor(object):
|
||||
""" Data Structure for Tensor in Colossal-AI
|
||||
1. It contains a torch.Tensor as an attribute.
|
||||
@@ -24,6 +23,7 @@ class ColoTensor(object):
|
||||
pin_memory=False,
|
||||
device=None,
|
||||
torch_tensor=torch.empty(0),
|
||||
shard_spec: str = None,
|
||||
):
|
||||
self._size = size
|
||||
self._dtype = dtype
|
||||
@@ -31,11 +31,29 @@ class ColoTensor(object):
|
||||
self._pin_memory = pin_memory
|
||||
self._device = device
|
||||
self._torch_tensor = torch_tensor
|
||||
self._shard_spec = shard_spec
|
||||
|
||||
@property
|
||||
def shard_spec(self) -> Optional[str]:
|
||||
return self._shard_spec
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self._torch_tensor.data
|
||||
|
||||
@property
|
||||
def grad(self):
|
||||
return self._torch_tensor.grad
|
||||
|
||||
@property
|
||||
def size(self):
|
||||
return self._size
|
||||
|
||||
def numel(self):
|
||||
return product(self._size)
|
||||
|
||||
@staticmethod
|
||||
|
||||
def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
|
||||
colo_t = ColoTensor(*tensor.size(),
|
||||
dtype=tensor.dtype,
|
||||
|
Reference in New Issue
Block a user