[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)

Author: Ziyue Jiang
Date: 2022-04-26 10:15:26 +08:00
Committed by: GitHub
Parent: 11f54c7b6b
Commit: 26d4ab8b03
6 changed files with 85 additions and 58 deletions


@@ -1,13 +1,12 @@
from colossalai.context import parallel_mode
from .op_wrapper import _COLOSSAL_OPS
import torch
from typing import Tuple, Optional
from numpy import product
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.layer.utils import divide
from colossalai.utils.cuda import get_current_device
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
class ColoTensor(object):
""" Data Structure for Tensor in Colossal-AI
@@ -28,7 +27,7 @@ class ColoTensor(object):
pin_memory=False,
device=None,
torch_tensor=torch.empty(0),
- shard_spec: str = None,
+ shard_spec: TensorSpec = TensorSpec(),
):
self._size = size
self._dtype = dtype
@@ -39,7 +38,7 @@ class ColoTensor(object):
self._shard_spec = shard_spec
@property
- def shard_spec(self) -> Optional[str]:
+ def shard_spec(self) -> TensorSpec:
return self._shard_spec
@property
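
Note on the new spec type: the `_shard` path below reads `num_action`, `compute_patterns`, `get_action_by_compute_pattern()` and a per-action `parallel_mode` from the spec. The mock below sketches a data structure with that shape for orientation only; it is an assumption, not the actual `colossalai.tensor.TensorSpec` / `ParallelAction` implementation, whose constructors may differ.

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional

class ComputePattern(Enum):
    # Only the pattern exercised by this diff; the real enum has more members.
    TP1DRow = auto()

@dataclass
class ParallelAction:
    # Hypothetical fields inferred from the diff: the pattern this action serves
    # and the parallel mode (process group) it shards over.
    compute_pattern: ComputePattern
    parallel_mode: str = "TENSOR"   # stand-in for colossalai's ParallelMode

@dataclass
class TensorSpec:
    parallel_actions: List[ParallelAction] = field(default_factory=list)

    @property
    def num_action(self) -> int:
        return len(self.parallel_actions)

    @property
    def compute_patterns(self) -> List[ComputePattern]:
        return [a.compute_pattern for a in self.parallel_actions]

    def get_action_by_compute_pattern(self, pattern: ComputePattern) -> Optional[ParallelAction]:
        return next((a for a in self.parallel_actions if a.compute_pattern == pattern), None)

# Under this reading, a default TensorSpec() means "no action, no sharding",
# which is why the constructor default changes from None to TensorSpec().
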
@@ -109,27 +108,27 @@ class ColoTensor(object):
device=self._device)
return self._torch_tensor
- def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
+ def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
self._shard_spec = spec
if lazy_shard == False:
self._shard()
def _shard(self):
assert self._shard_spec is not None, 'You should call set_spec() on a ColoTensor before calling _shard().'
- if self._shard_spec == "1Drow":  # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
-     num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-     local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-     dim = -1
-     chunk_size = divide(self._size[dim], num_partition)
-     device = get_current_device()
-     # Reshape to get shard for this rank and we don't want autograd
-     # recording here for the narrow op and 'local_shard' should be a
-     # leaf variable in the autograd graph.
-     self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
-     ).contiguous()  # TODO Shall we clone() here since detach() will point to the old tensor?
-     self._torch_tensor.requires_grad = self._requires_grad
-     self._size = self._torch_tensor.size()
-     self._device = device  # TODO A `fake` device now because torch_tensor.device always = cpu
+ if self._shard_spec.num_action == 1:
+     if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
+         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+         num_partition = gpc.get_world_size(parallel_action.parallel_mode)
+         local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+         dim = -1
+         chunk_size = divide(self._size[dim], num_partition)
+         # Reshape to get shard for this rank and we don't want autograd
+         # recording here for the narrow op and 'local_shard' should be a
+         # leaf variable in the autograd graph.
+         self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
+         ).contiguous()  # TODO Shall we clone() here since detach() will point to the old tensor?
+         self._torch_tensor.requires_grad = self._requires_grad
+         self._size = self._torch_tensor.size()
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
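
The rewritten branch resolves the process group from the spec's ParallelAction instead of hard-coding ParallelMode.TENSOR, then shards the last dimension by narrowing. The chunking arithmetic itself is plain torch; a self-contained sketch is below, using a hypothetical helper `shard_1d_row` with explicit rank/world_size in place of gpc.

import torch

def shard_1d_row(full: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Minimal sketch of the narrow-based sharding above, outside colossalai.

    Assumes the sharded (last) dimension divides evenly by world_size, which is
    what colossalai's `divide` helper asserts.
    """
    dim = -1
    chunk_size = full.size(dim) // world_size
    # detach() so the narrow op is not recorded by autograd; the local shard
    # becomes a leaf tensor, and contiguous() gives it its own storage.
    local = full.narrow(dim, rank * chunk_size, chunk_size).detach().contiguous()
    local.requires_grad = full.requires_grad
    return local

# Example: a (4, 8) weight split across 4 ranks along the last dim -> (4, 2) each.
w = torch.randn(4, 8, requires_grad=True)
shards = [shard_1d_row(w, r, 4) for r in range(4)]
assert all(s.shape == (4, 2) for s in shards)
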
@@ -151,5 +150,5 @@ class ColoTensor(object):
kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
return func(*args, **kwargs)
- def backward(self, retain_graph: bool = False):
-     self._torch_tensor.backward(retain_graph=retain_graph)
+ def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
+     self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
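
The extra `gradient` argument mirrors `torch.Tensor.backward`, which requires an explicit upstream gradient whenever the tensor is non-scalar; before this change ColoTensor.backward could only be called on scalar losses. A plain-torch illustration of the two cases (not using ColoTensor itself):

import torch

x = torch.randn(2, 3, requires_grad=True)
y = x * 2                                   # non-scalar output
y.backward(gradient=torch.ones_like(y))     # would raise without `gradient`
assert torch.equal(x.grad, torch.full_like(x, 2.0))

x.grad = None
loss = (x * 2).sum()                        # scalar output
loss.backward()                             # gradient=None is fine here
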