Repository: https://github.com/hpcaitech/ColossalAI.git
[Tensor] Add function to spec and update linear 1Drow and unit tests (#869)
@@ -1,13 +1,12 @@
-from colossalai.context import parallel_mode
 from .op_wrapper import _COLOSSAL_OPS
 
 import torch
 from typing import Tuple, Optional
 from numpy import product
 from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from colossalai.nn.layer.utils import divide
-from colossalai.utils.cuda import get_current_device
+
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
 
 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
@@ -28,7 +27,7 @@ class ColoTensor(object):
                  pin_memory=False,
                  device=None,
                  torch_tensor=torch.empty(0),
-                 shard_spec: str = None,
+                 shard_spec: TensorSpec = TensorSpec(),
                  ):
         self._size = size
         self._dtype = dtype
@@ -39,7 +38,7 @@ class ColoTensor(object):
         self._shard_spec = shard_spec
 
     @property
-    def shard_spec(self) -> Optional[str]:
+    def shard_spec(self) -> TensorSpec:
         return self._shard_spec
 
     @property
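With this pair of hunks the shard specification carried by a ColoTensor becomes a structured TensorSpec object instead of a bare string such as "1Drow", both at construction time (the default is now an empty TensorSpec()) and through the shard_spec property. A minimal sketch of construction after the change; apart from shard_spec, the keyword names and the import path for ColoTensor are assumptions inferred from the attributes visible in this diff, not confirmed by it:

# Sketch only: argument names other than shard_spec are inferred from the
# self._size / self._dtype / self._requires_grad attributes seen in this diff.
import torch
from colossalai.tensor import TensorSpec
from colossalai.tensor import ColoTensor     # assumed import path

t = ColoTensor(size=(16, 64),
               dtype=torch.float32,
               requires_grad=True)            # shard_spec defaults to TensorSpec()
assert isinstance(t.shard_spec, TensorSpec)   # property now returns TensorSpec, not Optional[str]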
@@ -109,27 +108,27 @@ class ColoTensor(object):
                                              device=self._device)
         return self._torch_tensor
 
-    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
+    def set_spec(self, spec: TensorSpec, lazy_shard: bool = False) -> None:
         self._shard_spec = spec
         if lazy_shard == False:
             self._shard()
 
     def _shard(self):
         assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
-        if self._shard_spec == "1Drow":    # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
-            num_partition = gpc.get_world_size(ParallelMode.TENSOR)
-            local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
-            dim = -1
-            chunk_size = divide(self._size[dim], num_partition)
-            device = get_current_device()
-            # Reshape to get shard for this rank and we don't want autograd
-            # recording here for the narrow op and 'local_shard' should be a
-            # leaf variable in the autograd graph.
-            self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
-            ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
-            self._torch_tensor.requires_grad = self._requires_grad
-            self._size = self._torch_tensor.size()
-            self._device = device    # TODO A `fake` device now because torch_tensor.device always = cpu
+        if self._shard_spec.num_action == 1:
+            if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
+                parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+                num_partition = gpc.get_world_size(parallel_action.parallel_mode)
+                local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+                dim = -1
+                chunk_size = divide(self._size[dim], num_partition)
+                # Reshape to get shard for this rank and we don't want autograd
+                # recording here for the narrow op and 'local_shard' should be a
+                # leaf variable in the autograd graph.
+                self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach(
+                ).contiguous()    # TODO Shall we clone() here since detach() will point to the old tensor?
+                self._torch_tensor.requires_grad = self._requires_grad
+                self._size = self._torch_tensor.size()
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
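After this hunk _shard() no longer matches the magic string "1Drow". It inspects the TensorSpec: when exactly one ParallelAction is registered (num_action == 1) and its compute pattern is ComputePattern.TP1DRow, the wrapped tensor is narrowed along its last dimension into world-size equal chunks and only the rank-local chunk is kept as a detached, contiguous leaf, with the process group taken from parallel_action.parallel_mode instead of the hard-coded ParallelMode.TENSOR. The pure-PyTorch sketch below reproduces that arithmetic outside Colossal-AI; world_size and rank stand in for gpc.get_world_size() and gpc.get_local_rank(), and shard_1d_row is a hypothetical helper, not part of the library:

# Standalone illustration of the 1D-row shard computed in _shard() above.
import torch

def shard_1d_row(full: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    dim = -1
    chunk_size = full.size(dim) // world_size    # divide() additionally checks divisibility
    # narrow + detach + contiguous, as in _shard(): the local shard becomes a
    # leaf tensor, so autograd does not record the narrow op.
    return full.narrow(dim, rank * chunk_size, chunk_size).detach().contiguous()

weight = torch.randn(16, 64)
local = shard_1d_row(weight, world_size=4, rank=1)   # keeps columns 16..31
assert local.shape == (16, 16) and local.is_contiguous()

In the real code path the spec is attached through set_spec(); how the TensorSpec and its ParallelAction are built is not shown in this hunk, only that the spec exposes num_action, compute_patterns and get_action_by_compute_pattern().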
@@ -151,5 +150,5 @@ class ColoTensor(object):
             kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
             return func(*args, **kwargs)
 
-    def backward(self, retain_graph: bool = False):
-        self._torch_tensor.backward(retain_graph=retain_graph)
+    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
+        self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)
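The final hunk widens backward() so that non-scalar ColoTensors can be differentiated: the call is forwarded verbatim to the wrapped torch.Tensor, so the new gradient argument follows ordinary torch.Tensor.backward semantics. A plain-PyTorch sketch of why the extra argument is needed:

# backward() on a non-scalar output requires an explicit upstream gradient of
# the same shape; ColoTensor.backward now simply forwards it to torch.
import torch

x = torch.randn(4, 8, requires_grad=True)
y = x * 2.0                                   # non-scalar output
y.backward(gradient=torch.ones_like(y))       # would raise without `gradient`
assert torch.equal(x.grad, torch.full_like(x, 2.0))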