Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[tensor] design DistSpec and DistSpecManager for ColoTensor (#934)
* add dist spec
* update linear op
* polish code
* polish code
* update embedding op
* polish unit tests
* polish unit tests
* polish comments
* polish code
* add test_dist_spec_mgr
* polish code
* refactor folder structure
* polish unit tests
* add get_process_group() for TensorSpec
* polish code
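Review note: this change collapses the old shard_spec/ShardPattern machinery on ColoTensor into a single TensorSpec that carries a distribution spec, and routes every layout change through DistSpecManager. Below is a minimal sketch of the resulting API surface, assuming the module path colossalai.tensor.colo_tensor (the page does not show the file name); TensorSpec(dist_spec.replicate()), init_from_torch_tensor(..., spec=...) and set_spec() are taken from the diff below, while the tensor contents are illustrative:

    import torch
    from colossalai.tensor import TensorSpec, dist_spec
    from colossalai.tensor.colo_tensor import ColoTensor    # assumed module path

    # Every ColoTensor now carries an explicit distribution spec;
    # the default (per the diff below) is a replicated layout.
    t = ColoTensor.init_from_torch_tensor(torch.randn(4, 4),
                                          spec=TensorSpec(dist_spec.replicate()))

    # Changing the spec later redistributes the payload through
    # DistSpecManager (set_spec -> to_dist_spec -> handle_trans_spec).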
@@ -1,13 +1,16 @@
 from .op_wrapper import _COLOSSAL_OPS

+from copy import copy
 import torch
 from typing import Tuple, Optional, Callable, Union
 from numpy import product
 from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
+from colossalai.tensor import TensorSpec, ComputePattern
 from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
 from .const import TensorType
+from colossalai.tensor import dist_spec
+from colossalai.tensor.dist_spec_mgr import DistSpecManager
+from colossalai.tensor.dist_spec import _DistSpec


 class ColoTensor(object):
@@ -28,15 +31,14 @@ class ColoTensor(object):
                  pin_memory=False,
                  device=None,
                  torch_tensor=torch.empty(0),
-                 shard_spec: TensorSpec = TensorSpec()):
+                 spec: TensorSpec = TensorSpec(dist_spec.replicate())):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
         self._device = device
         self._torch_tensor = torch_tensor
-        self._shard_spec = shard_spec
-        self._shard_pattern = ShardPattern.NA
+        self._spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None
@@ -44,8 +46,8 @@ class ColoTensor(object):
         return ColoTensor.init_from_torch_tensor(self.torch_tensor()[key])

     @property
-    def shard_spec(self) -> TensorSpec:
-        return self._shard_spec
+    def spec(self) -> TensorSpec:
+        return self._spec

     @property
     def shard_pattern(self):
@@ -96,13 +98,16 @@ class ColoTensor(object):
         return product(self._size)

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor,
+                               save_payload=True,
+                               spec: TensorSpec = TensorSpec(dist_spec.replicate())) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             device=tensor.device,
-                            torch_tensor=tensor if save_payload else torch.empty(0))
+                            torch_tensor=tensor if save_payload else torch.empty(0),
+                            spec=spec)
         return colo_t

     def del_torch_tensor(self, save_shape=False) -> None:
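Note on save_payload: with save_payload=False the wrapper keeps shape, dtype, device and spec metadata but drops the data itself (the torch.empty(0) placeholder above). A one-line sketch, reusing the imports from the earlier example:

    # Metadata-only ColoTensor: shape/dtype/spec survive, the payload does not.
    meta_only = ColoTensor.init_from_torch_tensor(torch.randn(4, 4), save_payload=False)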
@@ -127,85 +132,17 @@ class ColoTensor(object):
                                              device=self._device)
         return self._torch_tensor

-    def set_spec(self, spec: TensorSpec, shard: bool = True) -> None:
-        self._shard_spec = spec
-        if shard == True:
-            self.shard()
-
-    def set_shard_pattern(self, shard_pattern: ShardPattern):
-        self._shard_pattern = shard_pattern
-
-    def shard(self):
-        assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
-        if self._shard_pattern is not ShardPattern.NA:    # reshard
-            self.gather()
-        # Model Parameters
-        if self._shard_spec.num_action == 1:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
-            if parallel_action.compute_pattern in [
-                    ComputePattern.TP1DRow_Linear, ComputePattern.TP1DCol_Embedding, ComputePattern.TP1DCol_mm
-            ]:
-                self._shard_1d(parallel_action=parallel_action, dim=-1)
-                # We bind our ComputePattern on weight, which has to be transposed when linear().
-                self._shard_pattern = ShardPattern.Col
-            elif parallel_action.compute_pattern in [
-                    ComputePattern.TP1DCol_Linear, ComputePattern.TP1DRow_Embedding, ComputePattern.TP1DRow_mm
-            ]:
-                self._shard_1d(parallel_action=parallel_action, dim=0)
-                self._shard_pattern = ShardPattern.Row
-            else:
-                raise NotImplementedError
-
-    def gather(self):
-        assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
-        assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
-        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        dim = self._get_gather_dim()
-        self._torch_tensor = gather_forward_split_backward(self._torch_tensor, parallel_action.parallel_mode, dim=dim)
-        self._shard_pattern = ShardPattern.NA
-        self._size = self._torch_tensor.size()
-
-    def global_torch_tensor(self) -> torch.Tensor:
-        out_tensor = self.torch_tensor()
-        if self.is_gathered():
-            return out_tensor
-
-        parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
-        world_size = gpc.get_world_size(parallel_action.parallel_mode)
-        if world_size == 1:
-            return out_tensor
-
-        rank = gpc.get_local_rank(parallel_action.parallel_mode)
-        tensor_list = [torch.empty_like(out_tensor) for _ in range(world_size)]
-        tensor_list[rank] = out_tensor
-        torch.distributed.all_gather(tensor_list, out_tensor, group=gpc.get_group(parallel_action.parallel_mode))
-
-        dim = self._get_gather_dim()
-        out_tensor = torch.cat(tensor_list, dim=dim).contiguous()
-
-        return out_tensor
-
-    def is_gathered(self) -> bool:
-        return self._shard_pattern == ShardPattern.NA
+    def set_spec(self, spec: TensorSpec) -> None:
+        spec = copy(spec)
+        self.to_dist_spec(spec.dist_spec)
+        self._spec = spec

     def has_spec(self) -> bool:
-        return self._shard_spec is not None and self._shard_spec.num_action > 0
+        return self._spec.num_action > 0

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

-    def _shard_1d(self, parallel_action, dim=-1):
-        num_partition = gpc.get_world_size(parallel_action.parallel_mode)
-        local_rank = gpc.get_local_rank(parallel_action.parallel_mode)
-        chunk_size = divide(self._size[dim], num_partition)
-        # Reshape to get shard for this rank and we don't want autograd
-        # recording here for the narrow op and 'local_shard' should be a
-        # leaf variable in the autograd graph.
-        self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous(
-        )    # TODO Shall we clone() here since detach() will point to the old tensor?
-        self._torch_tensor.requires_grad = self._requires_grad
-        self._size = self._torch_tensor.size()
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         global _COLOSSAL_OPS
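Reviewer summary for this hunk: the old set_spec() mutated layout imperatively (shard()/gather() keyed on hard-coded ComputePattern lists), while the new set_spec() copies the spec and delegates to to_dist_spec(), so every old-layout to new-layout transition goes through one code path. A sketch of how a caller drives resharding after this change; replicate(), set_spec(), gpc.get_group() and gpc.get_world_size() appear in the diff, but the dist_spec.shard() argument names are an assumption, not shown here:

    from colossalai.core import global_context as gpc
    from colossalai.context import ParallelMode

    # Hypothetical 1D shard of dim 0 across the 1D tensor-parallel group
    # (the shard() signature is assumed, not taken from this diff).
    mode = ParallelMode.PARALLEL_1D
    row_shard = TensorSpec(dist_spec.shard(process_group=gpc.get_group(mode),
                                           dims=[0],
                                           num_partitions=[gpc.get_world_size(mode)]))

    t.set_spec(row_shard)                           # replicate -> shard
    t.set_spec(TensorSpec(dist_spec.replicate()))   # shard -> replicate (all-gather)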
@@ -278,15 +215,6 @@ class ColoTensor(object):
                 for output in outputs
             ])

-    def _get_gather_dim(self):
-        if self._shard_pattern == ShardPattern.Row:
-            dim = 0
-        elif self._shard_pattern == ShardPattern.Col:
-            dim = -1
-        else:
-            raise NotImplementedError
-        return dim
-
     def __mul__(self, other) -> "ColoTensor":
         if isinstance(other, ColoTensor):
             return ColoTensor.init_from_torch_tensor(self.torch_tensor() * other.torch_tensor())
@@ -296,3 +224,10 @@ class ColoTensor(object):
             raise TypeError(f'{type(other)} is not supported in ColoTensor __mul__')

     __rmul__ = __mul__
+
+    def to_dist_spec(self, dist_spec: _DistSpec) -> None:
+        self._torch_tensor = DistSpecManager.handle_trans_spec(self.torch_tensor(), self.spec.dist_spec, dist_spec)
+        if self._torch_tensor.is_leaf:
+            self._torch_tensor.requires_grad = self._requires_grad
+        self._size = self._torch_tensor.size()
+        self._spec.dist_spec = dist_spec
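Reviewer note on to_dist_spec(): DistSpecManager.handle_trans_spec() is now the single conversion entry point between distribution specs. The is_leaf guard matters because PyTorch only allows assigning requires_grad on leaf tensors; when the conversion runs through autograd-tracked ops (e.g. a gather), the result is a non-leaf and the assignment would raise a RuntimeError. Direct use looks like this (all names below appear in the diff; t is the ColoTensor from the earlier sketch):

    # Move the payload to a new layout without building a whole new TensorSpec.
    t.to_dist_spec(dist_spec.replicate())
    print(t.spec.dist_spec)    # _spec.dist_spec is kept in sync by to_dist_spec()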