[ColoTensor] improves init functions. (#1150)

Jiarui Fang 2022-06-21 18:28:38 +08:00 committed by GitHub
parent 8106d7b8c7
commit 8cdce0399c
5 changed files with 103 additions and 40 deletions

View File

@@ -35,7 +35,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                  data: Optional[torch.Tensor] = None,
                  requires_grad: bool = True,
                  spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.MODEL
         self._graph_node = None
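
For reference, a minimal sketch of constructing a ColoParameter against the signature shown in this hunk. Importing ColoParameter from colossalai.tensor and the exact constructor behaviour are assumptions; the default spec is the replicated one used throughout this commit.

import torch
from colossalai.tensor import ColoParameter, TensorSpec, distspec  # assumed export path

# Build a model parameter; after this change its spec is stored internally as `_tensor_spec`.
weight = ColoParameter(torch.randn(16, 16),
                       requires_grad=True,
                       spec=TensorSpec(distspec.replicate()))
assert weight.spec.dist_spec is not None   # the public `spec` property is unchanged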

View File

@@ -1,12 +1,13 @@
 from .op_wrapper import _COLOSSAL_OPS
-from .const import TensorType
 from copy import copy
 import torch
-from torch.overrides import get_default_nowrap_functions
 from colossalai.tensor import TensorSpec
+from .const import TensorType
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
+from torch.overrides import get_default_nowrap_functions


 def _convert_output(output):
@@ -18,34 +19,54 @@ def _convert_output(output):

 class ColoTensor(torch.Tensor):
-    """ Data Structure for Tensor in Colossal-AI
-    1. It contains a torch.Tensor as an attribute.
-    2. It supports lazy init the tensor's payload.
-    3. It can hijack the torch functions which using ColoTensors as args to our customized functions.
-    4. It supports distributing the tensor's payload to the shards among processes. (TODO)
+    """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
+
+    Args:
+        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
+        spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
+
+    The signature of __init__ has to be consistent with __new__ except for the first argument.
+    The class should be initialized with a torch tensor in the following ways.
+    1. directly init.
+    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
+    >>> # If initialized in a shard model, the tensor passed in is one shard of the global tensor.
+    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>>                             dims=[0],
+    >>>                             num_partitions=[world_size])
+    >>> tensor_spec = TensorSpec(shard_spec)
+    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    2. use static method from_torch_tensor
+    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
     """

     def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
-        """__new__
-        The signature of the __new__ has to be consistent with the torch.Tensor.
-        Args:
-            data (torch.Tensor): a torch tensor used as the payload the colotensor.
-            spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
-        Returns:
-            ColoTensor: a ColoTensor wrappers the data.
-        """
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, data.requires_grad)

     def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
-        self._spec = copy(spec)
+        self._tensor_spec = copy(spec)
         self._type = TensorType.NONMODEL
         self._graph_node = None

     @property
     def spec(self) -> TensorSpec:
-        return self._spec
+        return self._tensor_spec

     def set_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
-        self.convert_to_dist_spec_(spec.dist_spec)
-        self._spec = spec
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec

     def has_spec(self) -> bool:
-        return self._spec.parallel_action is not None
+        return self._tensor_spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
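
The usage path documented in the new docstring can be exercised end to end. Below is a condensed, hedged sketch of the shard-then-replicate flow, assuming the process groups have already been initialized via colossalai.launch, exactly as the test added at the end of this commit does.

import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# Each rank wraps its local (4, 5) shard of a logically (4 * world_size, 5) tensor.
world_size = gpc.get_group(ParallelMode.DATA).size()
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                            dims=[0],
                            num_partitions=[world_size])
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), TensorSpec(shard_spec))

# set_spec() converts the payload in place: replicate() gathers the shards,
# so every rank now holds the full tensor.
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
assert t.shape == torch.Size((4 * world_size, 5))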
@@ -74,16 +95,16 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

-    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
+    def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        self._spec.dist_spec = dist_spec
+        self._tensor_spec.dist_spec = dist_spec

     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
-        spec = copy(self._spec)
-        spec.dist_spec = dist_spec
+        tensor_spec = copy(self._tensor_spec)
+        tensor_spec.dist_spec = dist_spec
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
-        return ColoTensor.from_torch_tensor(ret, spec)
+        return ColoTensor.from_torch_tensor(ret, tensor_spec)

     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
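
The rename also makes the split clearer between the in-place conversion (now the private _convert_to_dist_spec, driven by set_spec) and the functional convert_to_dist_spec, which returns a new ColoTensor and leaves the original untouched. A sketch under the same initialized-process-group assumption as above:

import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ColoTensor, TensorSpec, distspec

world_size = gpc.get_group(ParallelMode.DATA).size()
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                            dims=[0], num_partitions=[world_size])
t = ColoTensor.from_torch_tensor(torch.randn(4, 5), TensorSpec(shard_spec))

# The functional form hands back a fresh ColoTensor with the requested layout;
# the original keeps its local shard and its sharded spec.
replicated = t.convert_to_dist_spec(distspec.replicate())
assert isinstance(replicated, ColoTensor)
assert replicated.shape == torch.Size((4 * world_size, 5))
assert t.shape == torch.Size((4, 5))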

View File

@@ -4,6 +4,7 @@ from numpy import prod
 from contextlib import contextmanager
 import torch
 import torch.distributed as dist
+from packaging import version

 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -56,6 +57,12 @@ class DistSpecManager:

     @staticmethod
     def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            # PyTorch lower than 1.11 does not support gathering a CPU tensor.
+            # Therefore, we move the tensor to GPU before the gather.
+            saved_dev = tensor.device
+            tensor.data = tensor.data.cuda()
+
         buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
         dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
@@ -66,6 +73,9 @@ class DistSpecManager:
                 new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
             buffer = new_buffer
         assert len(buffer) == 1
+
+        if version.parse(torch.__version__) < version.parse("1.11.0"):
+            buffer[0].data = buffer[0].data.to(saved_dev)
         return buffer[0]

     @staticmethod
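
The new guard is just a packaging version comparison around the gather. A standalone sketch of the check and the device round-trip it performs (mirroring the saved_dev handling above, and assuming a CUDA device is present when the old-PyTorch branch is taken):

import torch
from packaging import version

# Per the comment in _gather, PyTorch below 1.11 cannot gather a CPU tensor,
# so the payload is moved to GPU for the collective and moved back afterwards.
needs_cuda_roundtrip = version.parse(torch.__version__) < version.parse("1.11.0")

if needs_cuda_roundtrip and torch.cuda.is_available():
    x = torch.randn(4)
    saved_dev = x.device        # remember the original device
    x = x.cuda()                # dist.all_gather(...) would run here
    x = x.to(saved_dev)         # restore the original placement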

View File

@@ -24,28 +24,13 @@ class ParallelAction(object):

 class TensorSpec(object):
     """
-    It contains two aspects of information:
-    First, How are tensors distributed in Heterougenous memory space.
-    Second, if the tensor is a model parameter, the Spec contains the
-    parallel computation pattern of the Operator (Layer).
-    We have to consider the hybrid parallel mode.
+    The specification of the ColoTensor.
+
+    Args:
+        dist_spec (_DistSpec): describing the layout among processes.
+        parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
+            Defaults to None.
     """
-    # a list of parallel actions.
-    # For example: On 8 GPUs, a hybrid parallel strategy is applied using
-    # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
-    # parallel_action_list = [
-    #     ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #     ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
-    # ]
-    # When the ColoTensor is initialized,
-    # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1D_Linear.
-    # During Linear computation
-    # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1D_Linear.
-    # After Linear Op, we split the tensors according to ZeRO.

     def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
         self.parallel_action = parallel_action
         self.dist_spec = dist_spec
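
A short hedged sketch of how the slimmed-down TensorSpec interacts with ColoTensor.has_spec(). The imports follow the test file below; like the existing single-process tests, this only builds a replicated tensor and should not need a distributed launch.

import torch
from colossalai.tensor import ColoTensor, TensorSpec, distspec

# Only a dist_spec is given, so parallel_action stays None.
spec = TensorSpec(distspec.replicate())
assert spec.parallel_action is None

t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec)
assert not t.has_spec()   # has_spec() reports whether a ParallelAction is attached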

View File

@@ -3,6 +3,17 @@ import pytest
 from colossalai.tensor import ColoTensor
 from numpy import allclose
+import colossalai
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec
+from colossalai.core import global_context as gpc
+import torch.multiprocessing as mp
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.tensor import distspec, TensorSpec, ColoTensor
+from colossalai.context import ParallelMode
+from functools import partial


 def test_tensor_indexing():
     torch_t = torch.randn(2, 3)
@@ -25,8 +36,6 @@ def test_wrapped_tensor_func():
     # non-func attr
     assert t.is_cuda == t_ref.is_cuda

-    # TODO I don't find out a tensor function which returns None.
-
     # return 1 torch.Tensor
     t_abs = t.abs()
     assert isinstance(t_abs, ColoTensor) and torch.equal(t_abs, t_ref.abs())
@@ -47,3 +56,41 @@ def test_operand():
     t_res = t + t
     assert torch.allclose(t_ref_res, t_res)
+
+
+#### Test Distributed init a Colotensor
+def _run_tensor_shard_init(world_size):
+    t_ref = torch.randn(4, 5)
+    print(gpc.get_group(ParallelMode.DATA).size())
+    shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
+    tensor_spec = TensorSpec(shard_spec)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
+    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    assert t.shape == torch.Size((4 * world_size, 5))
+
+
+def _run_tensor_replicated_init(world_size):
+    t_ref = torch.randn(4 * world_size, 5)
+    t = ColoTensor.from_torch_tensor(t_ref.clone())
+    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
+
+
+def run_tensor_init(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    _run_tensor_shard_init(world_size)
+    _run_tensor_replicated_init(world_size)
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 2])
+@rerun_if_address_is_in_use()
+def _test_dist_init(world_size):
+    run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    # _test_dist_init(4)
+    test_new()