[refactor] move process group from _DistSpec to ColoTensor. (#1203)
colossalai/tensor/__init__.py

@@ -1,5 +1,5 @@
from .process_group import ProcessGroup
from .tensor_spec import TensorSpec
from .tensor_spec import ColoTensorSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter
@@ -9,7 +9,7 @@ from .param_op_hook import ParamOpHook, ParamOpHookManager
from . import distspec

__all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState',
'ProcessGroup'
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
'ColoTensorSpec', 'TensorSpec'
]
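
The package keeps both the legacy TensorSpec and the new ColoTensorSpec in its public API. A minimal, hedged sketch of the new import surface (the ProcessGroup construction assumes torch.distributed has already been initialized):

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

pg = ProcessGroup()    # default process group; assumes torch.distributed is initialized
spec = ColoTensorSpec(pg, distspec.replicate())
t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec)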

colossalai/tensor/colo_parameter.py

@@ -5,7 +5,7 @@ from copy import copy
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType
from colossalai.tensor import TensorSpec, distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager

@@ -28,7 +28,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __new__(cls,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, requires_grad)
@@ -36,11 +36,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
def __init__(self,
data: Optional[torch.Tensor] = None,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
spec: ColoTensorSpec = None) -> None:
ColoTensor.__init__(self, data, spec)
self._type = TensorType.MODEL
self._graph_node = None

# a list contains modules sharing this ColoParameter with others.
self._shared_param_modules = []

@@ -51,7 +49,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
@staticmethod
def from_torch_tensor(tensor: torch.Tensor,
requires_grad: bool = True,
spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoParameter':
spec: ColoTensorSpec = None) -> 'ColoParameter':
tensor = tensor.as_subclass(ColoParameter)
tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
return tensor
@@ -82,7 +80,9 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
tensor = ColoParameter(data,
self.requires_grad,
spec=ColoTensorSpec(self.get_process_group(), self.dist_spec, self.compute_spec))
memo[id(self)] = tensor
return tensor
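
Constructing a ColoParameter now takes an optional ColoTensorSpec instead of a TensorSpec default. A short sketch under the assumption that torch.distributed has been initialized; shapes are illustrative:

import torch
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup, distspec

pg = ProcessGroup()    # default data-parallel process group (assumption)
weight = ColoParameter(torch.randn(4, 4),
                       requires_grad=True,
                       spec=ColoTensorSpec(pg, distspec.replicate()))
# Passing spec=None falls back to a default process group and a replicated dist spec.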

colossalai/tensor/colo_tensor.py

@@ -4,18 +4,18 @@ from copy import copy
import torch
from torch.overrides import get_default_nowrap_functions

from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor import distspec, ProcessGroup
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from typing import Optional

def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
def _check_output(output):
if not isinstance(output, torch.Tensor):
raise RuntimeError
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
output = type(output)(_check_output(o) for o in output)
return output

@@ -23,28 +23,29 @@ class ColoTensor(torch.Tensor):
""" Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate()).
spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).

The signature of the function has to be consistent with the __new__ except for the 1st arg.
The class should be initialized with a torch tensor in the following ways.
1. directly init.
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> pg = ProcessGroup()
>>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
>>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
>>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
>>> dims=[0],
>>> num_partitions=[world_size])
>>> tensor_spec = TensorSpec(shard_spec)
>>> tensor_spec = ColoTensorSpec(pg, shard_spec)
>>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
2. use static method from_torch_tensor
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
>>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate()))
"""

def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
"""__new__
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload of the colotensor.
spec (TensorSpec, optional): the tensor spec of initialization. Defaults to TensorSpec(distspec.replicate())
spec (TensorSpec, optional): the tensor spec of initialization.
Returns:
ColoTensor: a ColoTensor wrapping the data.
"""
@@ -52,37 +53,72 @@ class ColoTensor(torch.Tensor):
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)

def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
self._tensor_spec = copy(spec)
def __init__(self, data: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> None:
# If no spec is given, use a DP process group and a replicated dist spec
if not spec:
self.has_initialized = False
self.dist_spec = distspec.replicate()
self.compute_spec = None
self.process_group = ProcessGroup()
else:
self.has_initialized = True
self.dist_spec = spec.dist_attr
self.compute_spec = spec.compute_attr
self.process_group = spec.pg

self._type = TensorType.NONMODEL
self._graph_node = None

@property
def tensor_spec(self) -> TensorSpec:
return self._tensor_spec

@tensor_spec.setter
def tensor_spec(self, tenseor_spec: TensorSpec):
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec

def set_tensor_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec

def has_compute_spec(self) -> bool:
return self._tensor_spec.compute_spec is not None
return self.compute_spec is not None

def is_model_data(self) -> bool:
return self._type == TensorType.MODEL

def get_process_group(self) -> 'ProcessGroup':
return self._tensor_spec.dist_spec.process_group
return self.process_group

def set_process_group(self, pg: ProcessGroup):
"""set_process_group
change the pg of the ColoTensor. Note that the valid use cases are limited.
It is only valid when the existing pg is DP-only and the dist spec is REPLICATE.
Args:
pg (ProcessGroup): target pg

Raises:
RuntimeError:
RuntimeError:
"""
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
if self.process_group.tp_world_size() != 1:
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")

if self.dist_spec.placement.value != 'r':
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")

self.process_group = pg

def get_tp_world_size(self) -> int:
return self._tensor_spec.dist_spec.process_group.tp_world_size()
return self.process_group.tp_world_size()

def set_dist_spec(self, dist_spec: _DistSpec):
"""set_dist_spec
set dist spec and change the payloads.
Args:
dist_spec (_DistSpec): target dist spec.
"""
assert isinstance(dist_spec, _DistSpec)
self._convert_to_dist_spec(dist_spec)

def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec:
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)}"
self.set_dist_spec(dist_spec)
if compute_spec:
self.compute_spec = compute_spec

def has_compute_pattern(self, compute_pattern):
return self.compute_spec.compute_pattern == compute_pattern

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
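
Because __init__ now accepts an optional ColoTensorSpec, a tensor created without a spec gets a replicated dist spec and a default process group, and set_process_group only works in that replicated, TP-free state. A hedged sketch (assumes torch.distributed is initialized, e.g. via colossalai.launch):

import torch
from colossalai.tensor import ColoTensor, ProcessGroup

# No spec: falls back to a replicated dist spec and a default (data-parallel) ProcessGroup.
t = ColoTensor.from_torch_tensor(torch.randn(2, 3))
assert t.is_replicate()
assert t.compute_spec is None

# Re-binding the process group is only allowed while the tensor is replicated
# and its current group has a tensor-parallel world size of 1.
t.set_process_group(ProcessGroup(tp_degree=1))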

@@ -100,7 +136,9 @@ class ColoTensor(torch.Tensor):
if func in get_default_nowrap_functions():
return ret
else:
return _convert_output(ret)
# TODO(jiaruifang) its parallel Op's duty to convert output activations
return ret
# return _check_output(ret)

def __repr__(self):
return f'ColoTensor: {super().__repr__()}'
@@ -113,30 +151,28 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): the target dist. spec.
"""
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
self._tensor_spec.dist_spec = dist_spec
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec

def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
tensor_spec = copy(self._tensor_spec)
tensor_spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, tensor_spec)
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))

def to_replicate_(self):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self, self.dist_spec, distspec.replicate(), self.process_group)
self.dist_spec = distspec.replicate()

def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
return self.convert_to_dist_spec(distspec.replicate())

@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
tensor = tensor.as_subclass(ColoTensor)
tensor.__init__(tensor, spec=spec)
return tensor
@@ -147,7 +183,7 @@ class ColoTensor(torch.Tensor):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
tensor = ColoTensor(data, spec=copy(ColoTensorSpec(self.process_group, self.dist_spec, self.compute_spec)))
memo[id(self)] = tensor
return tensor

@@ -165,12 +201,13 @@ class ColoTensor(torch.Tensor):
Returns:
ColoTensor: a tensor after viewed.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
return super().view(*args)
# TODO(jiaruifang) check why this not work
# self.data = self.to_replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, distspec.replicate(),
self.process_group)
self.dist_spec = distspec.replicate()
return super().view(*args)

def size_global(self, args: Optional[int] = None):
@@ -179,13 +216,13 @@ class ColoTensor(torch.Tensor):
Returns:
ColoTensor: a tensor after viewed.
"""
if self.tensor_spec.is_replicate():
if self.is_replicate():
if args is not None:
return super().size(args)
else:
return super().size()

spec = self.tensor_spec.dist_spec
spec = self.dist_spec
dims = spec.dims
num_partitions = spec.num_partitions
# import inspect
@@ -198,3 +235,19 @@ class ColoTensor(torch.Tensor):
return size_list[args]
else:
return torch.Size(size_list)

# Some API for dist spec check

def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.process_group.tp_world_size() == 1)

def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
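
The layout predicates that previously lived on TensorSpec (is_replicate, is_shard_1dcol, is_shard_1drow) are now methods of ColoTensor, and the conversions read the process group from the tensor itself. A hedged sketch, assuming torch.distributed is initialized, the tensors live on CUDA devices, and the global shape is divisible by the world size:

import torch
import torch.distributed as dist
from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup, distspec

world_size = dist.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
row_shard = distspec.shard(dims=[0], num_partitions=[world_size])

# Each rank wraps its local row-shard of an (8, 4) global tensor.
local = torch.randn(8 // world_size, 4, device='cuda')
t = ColoTensor.from_torch_tensor(local, ColoTensorSpec(pg, row_shard))

assert t.is_shard_1drow()
t.to_replicate_()                 # all-gathers the shards using t.process_group
assert t.is_replicate()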

colossalai/tensor/dist_spec_mgr.py

@@ -6,6 +6,7 @@ import torch
import torch.distributed as dist
from packaging import version
from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup

# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -29,15 +30,17 @@ def divide(numerator, denominator):
class TransformDistSpec(torch.autograd.Function):

@staticmethod
def forward(ctx, tensor, old_dist_spec, dist_spec, forward_trans_func, backward_trans_func):
def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
ctx.old_dist_spec = old_dist_spec
ctx.dist_spec = dist_spec
ctx.backward_trans_func = backward_trans_func
return forward_trans_func(tensor, old_dist_spec, dist_spec)
ctx.pg = pg
return forward_trans_func(tensor, old_dist_spec, dist_spec, pg)

@staticmethod
def backward(ctx, grad_outputs):
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec), None, None, None, None
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
ctx.pg), None, None, None, None, None

class DistSpecManager:
@@ -46,18 +49,17 @@ class DistSpecManager:

@staticmethod
def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
and dist_spec.process_group is not None:
raise NotImplementedError
pass

@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
tensor (torch.Tensor): a global (replicated) tensor before shard
dist_spec (_DistSpec): the distributed spec. to be sharded as.

pg (ProcessGroup): the process group of the corresponding ColoTensor
Returns:
torch.Tensor: a torch tensor after sharded.
"""
@@ -65,7 +67,7 @@ class DistSpecManager:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)

chunk = tensor
idx = dist_spec.process_group.tp_local_rank()
idx = pg.tp_local_rank()
num_parts = prod(dist_spec.num_partitions)
for i, dim in enumerate(dist_spec.dims):
num_parts //= dist_spec.num_partitions[i]
@@ -76,7 +78,7 @@ class DistSpecManager:
return chunk.clone().detach().contiguous()

@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec) -> torch.Tensor:
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
"""_gather gathers sharded tensors into a replicated one.
Args:
tensor (torch.Tensor): a sharded torch tensor
@@ -92,9 +94,9 @@ class DistSpecManager:
saved_dev = tensor.device
tensor.data = tensor.data.cuda()

buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
dist.all_gather(buffer, tensor, group=pg.tp_process_group())
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
new_buffer = []
dim = old_dist_spec.dims[i]
@@ -109,12 +111,14 @@ class DistSpecManager:
return buffer[0]

@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
world_size = old_dist_spec.process_group.tp_world_size()
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
world_size = pg.tp_world_size()
if world_size == 1:
return tensor

assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
assert tensor.device.type == "cuda", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device"

gather_dim = old_dist_spec.dims[0]
@@ -126,46 +130,50 @@ class DistSpecManager:

scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
return output_

@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return tensor

@staticmethod
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec, pg)

@staticmethod
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
if old_dist_spec == dist_spec:
return tensor
if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
# use all-to-all to save memory
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec)
tensor = DistSpecManager._gather(tensor, old_dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec)
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

@staticmethod
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
forward_trans_handle = getattr(DistSpecManager, f'_{old_dist_spec.placement.value}2{dist_spec.placement.value}')
if not DistSpecManager._use_autograd_function:
return forward_trans_handle(tensor, old_dist_spec, dist_spec)
return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
backward_trans_handle = getattr(DistSpecManager,
f'_{dist_spec.placement.value}2{old_dist_spec.placement.value}')
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, forward_trans_handle, backward_trans_handle)
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle)

@staticmethod
@contextmanager
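
Every transformation helper now receives the process group as an explicit argument rather than reading it from the dist spec. A hedged sketch of a direct handle_trans_spec call (assumes an initialized NCCL backend and a CUDA tensor; the shapes are illustrative):

import torch
import torch.distributed as dist
from colossalai.tensor import ProcessGroup, distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager

pg = ProcessGroup(tp_degree=dist.get_world_size())
replicated = torch.randn(4, 8, device='cuda')    # same value on every rank (assumption)
col_shard = distspec.shard(dims=[-1], num_partitions=[pg.tp_world_size()])

# pg is now passed explicitly; it used to be looked up via dist_spec.process_group.
local = DistSpecManager.handle_trans_spec(replicated, distspec.replicate(), col_shard, pg)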

colossalai/tensor/distspec.py

@@ -1,7 +1,5 @@
from enum import Enum
from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod
from typing import List

__all__ = ['replicate', 'shard']

@@ -13,10 +11,7 @@ class DistPlacementPattern(Enum):

class _DistSpec:

def __init__(self,
dist_placement_pattern: DistPlacementPattern,
process_group: Optional[ProcessGroup] = None,
**meta_info):
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
"""_DistSpec, Distributed Specification

Args:
@@ -25,7 +20,6 @@ class _DistSpec:
process_group (Optional[ProcessGroup], optional): the process group contains processes. Defaults to None.
"""
self.placement = dist_placement_pattern
self.process_group = process_group
for k, v in meta_info.items():
setattr(self, k, v)

@@ -45,14 +39,11 @@ class _DistSpec:
return res

def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
# process_group=None means global process group
return _DistSpec(DistPlacementPattern.REPLICATE, process_group)
def replicate() -> _DistSpec:
return _DistSpec(DistPlacementPattern.REPLICATE)

def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert process_group is not None and isinstance(process_group, ProcessGroup)
def shard(dims: List[int], num_partitions: List[int]) -> _DistSpec:
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
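
After this change the factory functions describe only the layout; the owning ColoTensor supplies the ProcessGroup when the spec is applied. A minimal sketch (the partition count 4 is an assumption):

from colossalai.tensor import distspec

rep = distspec.replicate()                            # no process_group argument anymore
col = distspec.shard(dims=[-1], num_partitions=[4])   # 1-D column shard across 4 ranks
row = distspec.shard(dims=[0], num_partitions=[4])    # 1-D row shard across 4 ranks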

colossalai/tensor/param_op_hook.py

@@ -3,6 +3,7 @@ from contextlib import contextmanager
from abc import ABC, abstractmethod
from typing import List, Tuple, Any
from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec

class ParamOpHook(ABC):
@@ -129,7 +130,7 @@ def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, arg.tensor_spec))
info.append((arg.__class__, ColoTensorSpec(arg.get_process_group(), arg.dist_spec, arg.compute_spec)))
else:
info.append(None)
return info

colossalai/tensor/process_group.py

@@ -20,6 +20,9 @@ class ProcessGroup:
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
if not torch.distributed.is_initialized():
return

assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
if rank is None:
self._rank = torch.distributed.get_rank()

colossalai/tensor/tensor_spec.py

@@ -1,44 +1,12 @@
import torch.distributed as dist
from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from .compute_spec import ComputeSpec, ComputePattern
from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass

class TensorSpec(object):
"""
The specification of the ColoTensor.
Args:
dist_spec (_DistSpec): descriping the layout among processes.
compute_spec (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
Defaults to None.
"""

def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
self.compute_spec = compute_spec
self.dist_spec = dist_spec

def get_process_group(self):
return self.dist_spec.process_group

def get_placement(self):
return self.dist_spec.placement

def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \
or (self.dist_spec.process_group.tp_world_size() == 1)

def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1

def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

def has_compute_pattern(self, compute_pattern: ComputePattern):
return self.compute_spec.compute_pattern == compute_pattern

def __repr__(self):
return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
@dataclass
class ColoTensorSpec:
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
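
ColoTensorSpec is a plain dataclass bundling the process group with the dist and compute specs. A hedged construction example (the TP1D compute pattern and the partition count derived from the TP world size are assumptions, and the ProcessGroup requires torch.distributed to be initialized):

from colossalai.tensor import ColoTensorSpec, ComputeSpec, ComputePattern, ProcessGroup, distspec

pg = ProcessGroup()                    # default process group
plain = ColoTensorSpec(pg)             # dist_attr defaults to a replicated _DistSpec
tp_spec = ColoTensorSpec(pg,
                         dist_attr=distspec.shard(dims=[-1], num_partitions=[pg.tp_world_size()]),
                         compute_attr=ComputeSpec(ComputePattern.TP1D))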