[refactor] move process group from _DistSpec to ColoTensor. (#1203)

Jiarui Fang
2022-07-06 16:15:16 +08:00
committed by GitHub
parent 5da87ce35d
commit ae7d3f4927
34 changed files with 452 additions and 367 deletions
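
The gist of the refactor: a ColoTensor no longer pulls its ProcessGroup out of the _DistSpec; it carries the group itself, bundled with the dist spec and compute spec in the new ColoTensorSpec. A minimal before/after sketch of the construction API, assuming torch.distributed is already initialized (e.g. via colossalai.launch) and that ProcessGroup accepts the tp_degree/dp_degree arguments shown in the process_group.py hunk below:

    import torch
    import torch.distributed as dist
    from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

    # Before: the group traveled inside the dist spec.
    #   spec = TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[world_size]))
    # After: the group lives on the tensor; the dist spec only describes the layout.
    world_size = dist.get_world_size()
    pg = ProcessGroup(tp_degree=world_size, dp_degree=1)
    shard_spec = distspec.shard(dims=[0], num_partitions=[world_size])
    t = ColoTensor.from_torch_tensor(torch.randn(8, 8), ColoTensorSpec(pg, shard_spec))
    print(t.get_process_group().tp_world_size(), t.is_shard_1drow())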

colossalai/tensor/__init__.py (View File)

@@ -1,5 +1,5 @@
from .process_group import ProcessGroup
from .tensor_spec import TensorSpec
from .tensor_spec import ColoTensorSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@@ -9,7 +9,7 @@ from .param_op_hook import ParamOpHook, ParamOpHookManager
from . import distspec
__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
'ProcessGroup'
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
'ColoTensorSpec', 'TensorSpec'
]

colossalai/tensor/colo_parameter.py (View File)

@@ -5,7 +5,7 @@ from copy import copy
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager
@@ -28,7 +28,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -36,11 +36,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
self._graph_node = None
# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []
@@ -51,7 +49,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
@@ -82,7 +80,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
memo[id(self)] = tensor
return tensor
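
ColoParameter drops the TensorSpec(distspec.replicate()) default; spec is now an optional ColoTensorSpec and the process group comes along with it. A small sketch of the new construction path, assuming an initialized distributed context:

    import torch
    from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup, distspec

    pg = ProcessGroup()                                   # default data-parallel group
    spec = ColoTensorSpec(pg, distspec.replicate())
    # spec may also be omitted; ColoTensor.__init__ then falls back to a default
    # ProcessGroup with a replicated dist spec (see colo_tensor.py below).
    p = ColoParameter.from_torch_tensor(torch.randn(16, 16), spec=spec)
    assert p.is_model_data() and p.is_replicate()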

colossalai/tensor/colo_tensor.py (View File)

@@ -4,18 +4,18 @@ from copy import copy
import torch
from torch.overrides import get_default_nowrap_functions
from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional
def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
def _check_output(output):
if not isinstance(output, torch.Tensor):
raise RuntimeError
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
output = type(output)(_check_output(o) for o in output)
return output
@@ -23,28 +23,29 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to None, which yields a replicated dist spec on a default ProcessGroup.
The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(dims=[0],
>>>                             num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec)
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
"""
def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""__new__
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
spec (ColoTensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrapping the data.
"""
@@ -52,37 +53,72 @@ class ColoTensor(torch.Tensor):
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
# If spec is not set, fall back to a DP process group and a replicated dist spec
if not spec:
self.has_initialized = False
self.dist_spec = distspec.replicate()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
self.has_initialized = True
self.dist_spec = spec.dist_attr
self.compute_spec = spec.compute_attr
self.process_group = spec.pg
self._type = TensorType.NONMODEL
self._graph_node = None
@property
def tensor_spec(self) -> TensorSpec:
return self._tensor_spec
@tensor_spec.setter
def tensor_spec(self, tenseor_spec: TensorSpec):
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def set_tensor_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def has_compute_spec(self) -> bool:
return self._tensor_spec.compute_spec is not None
return self.compute_spec is not None
def is_model_data(self) -> bool:
return self._type == TensorType.MODEL
def get_process_group(self) -> 'ProcessGroup':
return self._tensor_spec.dist_spec.process_group
return self.process_group
def set_process_group(self, pg: ProcessGroup):
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases are limited:
it is only valid when the existing pg is DP-only and the dist spec is REPLICATE.
Args:
pg (ProcessGroup): target pg
Raises:
RuntimeError: if the current process group already has a tensor-parallel world size larger than 1.
RuntimeError: if the current dist spec is not REPLICATE.
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
if self.process_group.tp_world_size() != 1:
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
if self.dist_spec.placement.value != 'r':
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
self.process_group = pg
def get_tp_world_size(self) -> int:
return self._tensor_spec.dist_spec.process_group.tp_world_size()
return self.process_group.tp_world_size()
def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
set dist spec and change the payloads.
Args:
dist_spec (_DistSpec): target dist spec.
"""
assert isinstance(dist_spec, _DistSpec)
self._convert_to_dist_spec(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec:
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
self.set_dist_spec(dist_spec)
if compute_spec:
self.compute_spec = compute_spec
def has_compute_pattern(self, compute_pattern):
return self.compute_spec.compute_pattern == compute_pattern
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
@@ -100,7 +136,9 @@ class ColoTensor(torch.Tensor):
if func in get_default_nowrap_functions():
return ret
else:
return _convert_output(ret)
# TODO(jiaruifang) it is the parallel Op's duty to convert output activations
return ret
# return _check_output(ret)
def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
@@ -113,30 +151,28 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): the target dist. spec.
"""
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
self._tensor_spec.dist_spec = dist_spec
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
tensor_spec = copy(self._tensor_spec)
tensor_spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, tensor_spec)
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
def to_replicate_(self):
"""to_replicate_
an in-place member function that converts the dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
self.dist_spec = distspec.replicate()
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converts the dist spec of the tensor to REPLICATE and returns a new ColoTensor
"""
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
return self.convert_to_dist_spec(distspec.replicate())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
@@ -147,7 +183,7 @@ class ColoTensor(torch.Tensor):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
memo[id(self)] = tensor
return tensor
@@ -165,12 +201,13 @@ class ColoTensor(torch.Tensor):
Returns:
ColoTensor: the tensor after the view operation.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
return super().view(*args)
# TODO(jiaruifang) check why this not work
# self.data = self.to_replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
self.process_group)
self.dist_spec = distspec.replicate()
return super().view(*args)
def size_global(self, args: Optional[int] = None):
@@ -179,13 +216,13 @@ class ColoTensor(torch.Tensor):
Returns:
torch.Size or int: the global size of the tensor, or the size along the given dimension.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
if args is not None:
return super().size(args)
else:
return super().size()
spec = self.tensor_spec.dist_spec
spec = self.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
@@ -198,3 +235,19 @@ class ColoTensor(torch.Tensor):
return size_list[args]
else:
return torch.Size(size_list)
# Some API for dist spec check
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.process_group.tp_world_size() == 1)
def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
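
The bookkeeping formerly reached through the tensor_spec property is now split across dist_spec, compute_spec and process_group, with set_process_group guarded as described in its docstring. A short sketch of the new mutators, assuming a distributed launch and a tensor created under the default process group:

    import torch
    from colossalai.tensor import ColoTensor, ProcessGroup, distspec

    t = ColoTensor.from_torch_tensor(torch.randn(4, 4))   # spec=None -> replicated, default pg
    assert t.is_replicate()

    # Re-binding the process group is only allowed while the dist spec is REPLICATE
    # and the current group has no tensor parallelism (tp_world_size() == 1).
    t.set_process_group(ProcessGroup())

    # set_tensor_spec now takes the two halves separately instead of a TensorSpec object.
    t.set_tensor_spec(distspec.replicate(), compute_spec=None)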

colossalai/tensor/dist_spec_mgr.py (View File)

@@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
from packaging import version
from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup
# TODO(jiaruifang) circular import, move the divide to colossalai.commons.
@@ -29,15 +30,17 @@ def divide(numerator, denominator):
class TransformDistSpec(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, old_dist_spec, dist_spec, forward_trans_func, backward_trans_func):
def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
ctx.old_dist_spec = old_dist_spec
ctx.dist_spec = dist_spec
ctx.backward_trans_func = backward_trans_func
return forward_trans_func(tensor, old_dist_spec, dist_spec)
ctx.pg = pg
return forward_trans_func(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def backward(ctx, grad_outputs):
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec), None, None, None, None
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
ctx.pg), None, None, None, None, None
class DistSpecManager:
@@ -46,18 +49,17 @@ class DistSpecManager:
@staticmethod
def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
raise NotImplementedError
pass
@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
tensor (torch.Tensor): a global (replicated) tensor before shard
dist_spec (_DistSpec): the distributed spec. to be sharded as.
pg (ProcessGroup): the process group of the corresponding ColoTensor
Returns:
torch.Tensor: a torch tensor after sharded.
"""
@@ -65,7 +67,7 @@ class DistSpecManager:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
chunk = tensor
idx = dist_spec.process_group.tp_local_rank()
idx = pg.tp_local_rank()
num_parts = prod(dist_spec.num_partitions)
for i, dim in enumerate(dist_spec.dims):
num_parts //= dist_spec.num_partitions[i]
@@ -76,7 +78,7 @@ class DistSpecManager:
return chunk.clone().detach().contiguous()
@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
"""_gather gather sharded tensors to a replicated one.
Args:
tensor (torch.Tensor): a sharded torch tensor
@@ -92,9 +94,9 @@ class DistSpecManager:
saved_dev = tensor.device
tensor.data = tensor.data.cuda()
buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
dist.all_gather(buffer, tensor, group=pg.tp_process_group())
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
new_buffer = []
dim = old_dist_spec.dims[i]
@@ -109,12 +111,14 @@ class DistSpecManager:
return buffer[0]
@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
world_size = old_dist_spec.process_group.tp_world_size()
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
world_size = pg.tp_world_size()
if world_size == 1:
return tensor
assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
assert tensor.device.type == "cuda", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device"
gather_dim = old_dist_spec.dims[0]
@@ -126,46 +130,50 @@ class DistSpecManager:
scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
return output_
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return tensor
@staticmethod
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec, pg)
@staticmethod
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
if old_dist_spec == dist_spec:
return tensor
if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
# use all-to-all to save memory
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
tensor = DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
if not DistSpecManager._use_autograd_function:
return forward_trans_handle(tensor, old_dist_spec, dist_spec)
return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
backward_trans_handle = getattr(DistSpecManager,
f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, forward_trans_handle, backward_trans_handle)
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle)
@staticmethod
@contextmanager
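
Every conversion routine now receives the ProcessGroup explicitly instead of reading it from old_dist_spec.process_group. A sketch of driving a layout change through handle_trans_spec directly, assuming an NCCL launch where each rank holds one row-shard on a CUDA device:

    import torch
    import torch.distributed as dist
    from colossalai.tensor import ProcessGroup, distspec
    from colossalai.tensor.dist_spec_mgr import DistSpecManager

    world_size = dist.get_world_size()
    pg = ProcessGroup(tp_degree=world_size, dp_degree=1)

    local_shard = torch.randn(2, 8, device='cuda')        # each rank holds a (2, 8) shard
    shard = distspec.shard(dims=[0], num_partitions=[world_size])

    # shard -> replicate dispatches to _s2r, which all-gathers over pg.tp_process_group().
    full = DistSpecManager.handle_trans_spec(local_shard, shard, distspec.replicate(), pg)
    assert full.shape == (2 * world_size, 8)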

colossalai/tensor/distspec.py (View File)

@@ -1,7 +1,5 @@
from enum import Enum
from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod
from typing import List
__all__ = ['replicate', 'shard']
@@ -13,10 +11,7 @@ class DistPlacementPattern(Enum):
class _DistSpec:
def __init__(self,
dist_placement_pattern: DistPlacementPattern,
process_group: Optional[ProcessGroup] = None,
**meta_info):
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
"""_DistSpec, Distributed Specification
Args:
@@ -25,7 +20,6 @@ class _DistSpec:
process_group (Optional[ProcessGroup], optional): the process group that contains the involved processes. Defaults to None.
"""
self.placement = dist_placement_pattern
self.process_group = process_group
for k, v in meta_info.items():
setattr(self, k, v)
@@ -45,14 +39,11 @@ class _DistSpec:
return res
def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
# process_group=None means global process group
return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
def replicate() -> _DistSpec:
return _DistSpec(DistPlacementPattern.REPLICATE)
def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert process_group is not None and isinstance(process_group, ProcessGroup)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
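
shard() and replicate() no longer take a ProcessGroup, so the prod(num_partitions) == tp_world_size() assertion that used to live here has to be enforced by whoever pairs the spec with a group (e.g. when building a ColoTensorSpec). A sketch of the new call with that check done on the caller side, assuming all ranks form one tensor-parallel group:

    import torch.distributed as dist
    from numpy import prod
    from colossalai.tensor import ProcessGroup, distspec

    tp_size = dist.get_world_size()
    spec = distspec.shard(dims=[-1], num_partitions=[tp_size])   # column-wise 1D shard

    # The consistency check dropped from shard() is now the caller's responsibility:
    pg = ProcessGroup(tp_degree=tp_size, dp_degree=1)
    assert prod(spec.num_partitions) == pg.tp_world_size()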

colossalai/tensor/param_op_hook.py (View File)

@@ -3,6 +3,7 @@ from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec
class ParamOpHook(ABC):
@@ -129,7 +130,7 @@ def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, arg.tensor_spec))
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
else:
info.append(None)
return info
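
With the tensor_spec property gone, hook code snapshots a tensor's layout by rebuilding a ColoTensorSpec from the three attributes that now live on the tensor, exactly as _get_colo_tensors_info does above. A hypothetical helper (not part of the commit) showing the same pattern:

    from colossalai.tensor import ColoTensor, ColoTensorSpec

    def snapshot_spec(t: ColoTensor) -> ColoTensorSpec:
        # Rebuild the spec from the attributes the refactor moved onto ColoTensor.
        return ColoTensorSpec(t.get_process_group(), t.dist_spec, t.compute_spec)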

colossalai/tensor/process_group.py (View File)

@@ -20,6 +20,9 @@ class ProcessGroup:
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
if not torch.distributed.is_initialized():
return
assert torch.distributed.is_initialized(), "ProcessGroup must be used after torch.distributed is initialized"
if rank is None:
self._rank = torch.distributed.get_rank()
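
The early return lets a ProcessGroup be instantiated before torch.distributed is initialized (it simply stays unconfigured) instead of tripping the assertion. A minimal sketch of that behavioral difference, under this reading of the hunk:

    import torch
    from colossalai.tensor import ProcessGroup

    if not torch.distributed.is_initialized():
        pg = ProcessGroup()      # previously an AssertionError; now an unconfigured object
    else:
        pg = ProcessGroup(tp_degree=1, dp_degree=torch.distributed.get_world_size())
        print(pg.tp_world_size())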

colossalai/tensor/tensor_spec.py (View File)

@@ -1,44 +1,12 @@
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from .compute_spec import ComputeSpec, ComputePattern
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
class TensorSpec(object):
"""
The specification of the ColoTensor.
Args:
dist_spec (_DistSpec): descriping the layout among processes.
compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
Defaults to None.
"""
def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
self.compute_spec = compute_spec
self.dist_spec = dist_spec
def get_process_group(self):
return self.dist_spec.process_group
def get_placement(self):
return self.dist_spec.placement
def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.dist_spec.process_group.tp_world_size() == 1)
def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def has_compute_pattern(self, compute_pattern: ComputePattern):
return self.compute_spec.compute_pattern == compute_pattern
def __repr__(self):
return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
@dataclass
class ColoTensorSpec:
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
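
ColoTensorSpec replaces TensorSpec as a plain dataclass: pg is required, dist_attr defaults to a replicated _DistSpec, and compute_attr to None. A usage sketch, assuming an initialized distributed context and that ComputePattern.TP1D exists in compute_spec as before this commit:

    from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, distspec

    pg = ProcessGroup()
    # Replicated spec with no compute pattern attached:
    plain = ColoTensorSpec(pg)
    # Row-sharded spec tagged for 1D tensor-parallel compute:
    sharded = ColoTensorSpec(pg,
                             dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()]),
                             compute_attr=ComputeSpec(ComputePattern.TP1D))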