Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 19:40:28 +00:00)
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692)
* [legacy] remove cli of benchmark and update optim (#4690)
* [legacy] remove cli of benchmark and update optim
* [doc] fix cli doc test
* [legacy] fix engine clip grad norm
* [legacy] remove outdated colo tensor (#4694)
* [legacy] remove outdated colo tensor
* [test] fix test import
* [legacy] move outdated zero to legacy (#4696)
* [legacy] clean up utils (#4700)
* [legacy] clean up utils
* [example] update examples
* [legacy] clean up amp
* [legacy] fix amp module
* [legacy] clean up gpc (#4742)
* [legacy] clean up context
* [legacy] clean core, constants and global vars
* [legacy] refactor initialize
* [example] fix examples ci
* [example] fix examples ci
* [legacy] fix tests
* [example] fix gpt example
* [example] fix examples ci
* [devops] fix ci installation
* [example] fix examples ci
@@ -1,18 +1,11 @@
from . import distspec
from .colo_parameter import ColoParameter
from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor

__all__ = [
    'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
    'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook',
    'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
    'merge_same_dim_mesh_list'
]
@@ -1,29 +0,0 @@
from enum import Enum


class ComputePattern(Enum):
    TP1D = 0
    TP2D = 1
    TP2P5D = 2
    TP3D = 3


class ComputeSpec(object):
    """ComputeSpec
    The specification for the computation pattern.

    Args:
        compute_pattern (ComputePattern): an Enum instance for the compute pattern.
    """

    def __init__(self, compute_pattern: ComputePattern) -> None:
        assert isinstance(compute_pattern, ComputePattern)
        self.compute_pattern = compute_pattern
        # Make sure output tensors are replicated
        self.output_replicate = True

    def __repr__(self):
        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'

    def set_output_replicate(self, flag: bool = True):
        self.output_replicate = flag
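For orientation, a brief usage sketch of the ComputeSpec API removed above; the choice of pattern and flag is illustrative only.

# Declare 1-D tensor parallelism for a parameter's computation (illustrative values).
spec = ComputeSpec(ComputePattern.TP1D)
spec.set_output_replicate(False)    # keep op outputs sharded instead of replicating them
print(spec)    # -> ComputeSpec(pattern=ComputePattern.TP1D, replicate_output=False)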
@@ -1,6 +0,0 @@
from enum import Enum


class TensorType(Enum):
    MODEL = 0
    NONMODEL = 1    # mainly activations
@@ -1,196 +0,0 @@
from contextlib import contextmanager

import torch
import torch.distributed as dist
from numpy import prod

from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup


# TODO(jiaruifang) circular import, move the divide to colossalai.commons.
# colossalai.tensor shall not import any submodule from colossalai.nn
def divide(numerator, denominator):
    """Only allow exact division.

    Args:
        numerator (int): Numerator of the division.
        denominator (int): Denominator of the division.

    Returns:
        int: the result of exact division.
    """
    assert denominator != 0, 'denominator can not be zero'
    assert numerator % denominator == 0, \
        '{} is not divisible by {}'.format(numerator, denominator)
    return numerator // denominator


class TransformDistSpec(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
        ctx.old_dist_spec = old_dist_spec
        ctx.dist_spec = dist_spec
        ctx.backward_trans_func = backward_trans_func
        ctx.pg = pg
        return forward_trans_func(tensor, old_dist_spec, dist_spec, pg)

    @staticmethod
    def backward(ctx, grad_outputs):
        return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
                                       ctx.pg), None, None, None, None, None


class DistSpecManager:

    _use_autograd_function: bool = True

    @staticmethod
    def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
        pass

    @staticmethod
    def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                  pg: ProcessGroup) -> torch.Tensor:
        """_shard_as: shard the tensor w.r.t. a distributed specification.
        Assuming the tensor passed in is a global (replicated) tensor.

        Args:
            tensor (torch.Tensor): a global (replicated) tensor before sharding
            dist_spec (_DistSpec): the distributed spec to be sharded as.
            pg (ProcessGroup): the process group of the corresponding colotensor

        Returns:
            torch.Tensor: the torch tensor after sharding.
        """
        assert old_dist_spec.placement.value == 'r', "The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)

        chunk = tensor
        idx = pg.tp_local_rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]

            chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
            chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
            idx %= num_parts
        return chunk.clone().detach().contiguous()

    @staticmethod
    def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        """_gather gathers sharded tensors into a replicated one.

        Args:
            tensor (torch.Tensor): a sharded torch tensor
            old_dist_spec (_DistSpec): the distributed spec of the tensor.

        Returns:
            torch.Tensor: a replicated tensor.
        """
        assert old_dist_spec.placement.value == 's', "The old_dist_spec of DistSpecManager._gather must be SHARD!"
        is_cpu_tensor = False
        if tensor.device.type == 'cpu':
            # PyTorch versions lower than 1.11 do not support gathering a CPU tensor.
            # Therefore, we transfer the tensor to GPU before the gather.
            saved_dev = tensor.device
            tensor.data = tensor.data.cuda()
            is_cpu_tensor = True

        buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
        assert tensor.device.type == 'cuda'
        dist.all_gather(buffer, tensor, group=pg.tp_process_group())
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
            num_parts = old_dist_spec.num_partitions[i]
            for start in range(0, len(buffer), num_parts):
                new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
            buffer = new_buffer
        assert len(buffer) == 1

        if is_cpu_tensor:
            buffer[0].data = buffer[0].data.to(saved_dev)
        return buffer[0]

    @staticmethod
    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                    pg: ProcessGroup) -> torch.Tensor:
        world_size = pg.tp_world_size()
        if world_size == 1:
            return tensor

        assert tensor.device.type == "cuda", \
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
            f"collective function, however, we got {tensor.device.type} device"

        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
        shapes = list(tensor.shape)
        scattered_dim_size = shapes[scatter_dim] // world_size
        gathered_dim_size = shapes[gather_dim] * world_size
        shapes[scatter_dim] = scattered_dim_size

        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())

        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
        return output_

    @staticmethod
    def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return tensor

    @staticmethod
    def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

    @staticmethod
    def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        return DistSpecManager._gather(tensor, old_dist_spec, pg)

    @staticmethod
    def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        if old_dist_spec == dist_spec:
            return tensor
        if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
            # use all-to-all to save memory
            return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
        tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
        return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

    @staticmethod
    def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
                          pg: ProcessGroup) -> torch.Tensor:
        assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
        assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"

        trans_func_key = (old_dist_spec.placement, dist_spec.placement)
        trans_funcs = {
            (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
            (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
            (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
        }

        forward_trans_handle = trans_funcs[trans_func_key]
        if not DistSpecManager._use_autograd_function:
            return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)

        backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]

        return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
                                       backward_trans_handle)

    @staticmethod
    @contextmanager
    def no_grad():
        try:
            DistSpecManager._use_autograd_function = False
            yield
        finally:
            DistSpecManager._use_autograd_function = True
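Taken together, the manager above converts a tensor between placements by dispatching on the (old placement, new placement) pair. A rough usage sketch follows; the 4-rank setup, tensor shape and partition count are assumptions for illustration, and ReplicaSpec/ShardSpec come from the distspec file in the next hunk.

# Assumed setup: torch.distributed already initialized with 4 ranks, tp_degree=4.
pg = ProcessGroup(tp_degree=4)
replica = ReplicaSpec()
shard = ShardSpec(dims=[0], num_partitions=[4])

x = torch.randn(8, 16).cuda()                                              # replicated on every rank
x_shard = DistSpecManager.handle_trans_spec(x, replica, shard, pg)         # each rank keeps a (2, 16) slice
x_full = DistSpecManager.handle_trans_spec(x_shard, shard, replica, pg)    # all-gathered back to (8, 16)

with DistSpecManager.no_grad():    # bypass the autograd-aware TransformDistSpec path
    y = DistSpecManager.handle_trans_spec(x, replica, shard, pg)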
@@ -1,78 +0,0 @@
from enum import Enum
from typing import List

__all__ = ['ReplicaSpec', 'ShardSpec']


class DistPlacementPattern(Enum):
    REPLICATE = 'r'
    SHARD = 's'


class _DistSpec:
    """_DistSpec

    A class that describes a distributed specification.
    The DistSpec only works for tensor parallel process groups,
    because the dist spec of a data parallel process group can be deduced automatically.
    This is an internal data structure.
    The API for users should be `ShardSpec` and `ReplicaSpec`.

    Args:
        dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
            The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
        process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
    """

    def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
        self.placement = dist_placement_pattern
        for k, v in meta_info.items():
            setattr(self, k, v)

    def __eq__(self, other: "_DistSpec") -> bool:
        if dir(self) != dir(other):
            return False
        for attr in dir(self):
            if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def __repr__(self) -> str:
        attr_list = []
        for attr in dir(self):
            if not attr.startswith('__'):
                attr_list.append(f'{attr}={str(getattr(self, attr))}')
        attr_str = ", ".join(attr_list)
        return "DistSpec(" + attr_str + ")"


def ReplicaSpec() -> _DistSpec:
    """ReplicaSpec

    A distributed specification representing a tensor replicated among the tensor parallel process group.

    Returns:
        _DistSpec: a replicated dist spec instance.
    """
    return _DistSpec(DistPlacementPattern.REPLICATE)


def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
    """ShardSpec

    A distributed specification representing a tensor sharded among the tensor parallel process group.

    Note:
        Currently, only sharding on one dimension is valid. In other words, dims should be of size 1.

    Args:
        dims (List[int]): a list of dimensions to shard on.
        num_partitions (List[int]): a list of the number of partitions along each dimension.

    Returns:
        _DistSpec: a sharded dist spec instance.
    """
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
    return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
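A small sketch of the two user-facing constructors above; the partition count is illustrative, and the printed output shown in the comments is approximate.

# Shard dimension 0 of a tensor into 2 partitions across the TP group.
row_shard = ShardSpec(dims=[0], num_partitions=[2])
replica = ReplicaSpec()
print(row_shard.placement)         # DistPlacementPattern.SHARD
print(row_shard)                   # roughly: DistSpec(dims=(0,), num_partitions=(2,), placement=DistPlacementPattern.SHARD)
print(replica == ReplicaSpec())    # True: _DistSpec compares all non-dunder attributes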
@@ -1,53 +0,0 @@
from typing import (
    Callable,
    Dict,
)
import functools

# Custom sharded ops
_COLOSSAL_OPS: Dict[str, Callable] = {}


def _register_colo_op(op, func):
    global _COLOSSAL_OPS
    _COLOSSAL_OPS[op] = func


def colo_op_impl(func):
    """
    Provides a way for users to write their own custom operator. This
    can be used to override existing ColoTensor operators or write a new
    one not supported by ColoTensor. If the operator in question is covered
    by ``__torch_function__`` dispatch and has a ColoTensor as any of its
    parameters, the function provided will be invoked for that operator.

    Example:
        >>> @colo_op_impl(torch.nn.functional.linear)
        >>> def my_custom_linear(types, args, kwargs, process_group):
        >>>     ....
        >>>
        >>> input = torch.rand(10, 32)
        >>> weight = ColoTensor(torch.rand(32, 16))
        >>> bias = ColoTensor(torch.rand(16))
        >>> # This will call `my_custom_linear` instead of the default.
        >>> torch.nn.functional.linear(input, weight, bias)

    The types, args and kwargs parameters are the same parameters that are
    passed to ``__torch_function__`` dispatch API
    (https://pytorch.org/docs/stable/notes/extending.html#extending-torch).

    Args:
        func(Callable): Torch function for which we want to provide a sharded
            implementation (ex: torch.nn.functional.linear)
    """

    def decorator_sharded_func(wrapped_func):
        _register_colo_op(func, wrapped_func)

        @functools.wraps(wrapped_func)
        def wrapper(*args, **kwargs):
            return wrapped_func(*args, **kwargs)

        return wrapper

    return decorator_sharded_func
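To make the registry above concrete, here is a hedged registration sketch; the handler name (colo_linear) and its body are placeholders for illustration, not the real ColossalAI implementation.

import torch

@colo_op_impl(torch.nn.functional.linear)
def colo_linear(types, args=(), kwargs=None, process_group=None):
    # Placeholder handler: a real implementation would inspect the ColoTensor
    # specs of the inputs and run a (possibly tensor-parallel) linear.
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return torch.nn.functional.linear(input_tensor, weight, bias)

# The decorator keys the registry by the torch function being overridden;
# the stored entry is the undecorated handler.
assert _COLOSSAL_OPS[torch.nn.functional.linear].__name__ == 'colo_linear'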
@@ -1,319 +0,0 @@
from typing import List, Optional

import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger


class PyTorchProcessGroupDict(metaclass=SingletonMeta):

    def __init__(self):
        # distributed settings
        # use this dict to record all PyTorch ProcessGroups
        self.dict = {}
        # set a distributed logger
        self.logger = get_dist_logger('ProcessGroup')

    def log_pg_init(self, rank_list: List[int], backend: str):
        str_list = ["Pytorch ProcessGroup Init:"]
        str_list.append(f"backend: {backend}")
        str_list.append(f"ranks: {rank_list}")
        self.logger.info("\n\t".join(str_list), ranks=[0])

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        """Reuse a PyTorch ProcessGroup when such a group has already been initialized.
        """
        # we need to convert the passed list to a tuple
        # since a list is unhashable
        processgroup_key = (backend, tuple(rank_list))
        if processgroup_key not in self.dict:
            self.log_pg_init(rank_list=rank_list, backend=backend)
            self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
        return self.dict[processgroup_key]


PYTORCHPGDICT_ = PyTorchProcessGroupDict()


class ProcessGroup:
    """ProcessGroup
    ProcessGroup indicates how processes are organized in groups for parallel execution using Tensor Parallelism and Data Parallelism.

    NOTE: the ProcessGroup must be used after `torch.distributed` is initialized, e.g. via `torch.distributed.init_process_group()`.

    Args:
        rank: the global rank of the current process.
        ranks: List[int], a list of rank ids belonging to this process group.
        backend: str, the backend of the process group.
        tp_degree: Optional[int], tensor parallelism degree. How many processes are inside a tp process group. Defaults to None, which means 1.
        dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. Defaults to None, which means len(ranks).
    """

    def __init__(self,
                 rank: Optional[int] = None,
                 ranks: Optional[List[int]] = None,
                 tp_degree: Optional[int] = None,
                 dp_degree: Optional[int] = None) -> None:
        if not torch.distributed.is_initialized():
            self.is_init = False
            return

        assert torch.distributed.is_initialized(), "ProcessGroup must be used after torch.distributed is initialized"

        self._rank = torch.distributed.get_rank()
        if rank is not None:
            assert self._rank == rank    # make sure that the global rank is correct

        if ranks is None:
            self._rank_list = list(range(torch.distributed.get_world_size()))
        else:
            self._rank_list = ranks
            self._rank_list.sort()    # ensure that the list is in order

        self._world_size = len(self._rank_list)

        if dp_degree is None and tp_degree is None:
            self._dp_degree = self._world_size
            self._tp_degree = 1
        elif dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, \
                f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size // dp_degree
        elif not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, \
                f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size // tp_degree
        else:
            self._dp_degree = dp_degree
            self._tp_degree = tp_degree
            assert self._dp_degree * self._tp_degree == self._world_size, \
                f"the world size {self._world_size} should equal the product of DP degree {self._dp_degree} " \
                f"and TP degree {self._tp_degree}"

        self._tp_rank_list = None
        self._dp_rank_list = None

        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
            if self._rank in i_tp_list:
                self._tp_rank_list = i_tp_list

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
            if self._rank in j_dp_list:
                self._dp_rank_list = j_dp_list

        self._has_cpu_groups = False
        self.is_init = True

    def set_cpu_groups(self):
        """set_cpu_groups
        Initialize PyTorch process groups for CPU communications.
        """
        if self.has_cpu_groups:
            return

        for i in range(self._dp_degree):
            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
            PYTORCHPGDICT_.get(i_tp_list, 'gloo')

        for j in range(self._tp_degree):
            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
            PYTORCHPGDICT_.get(j_dp_list, 'gloo')

        self._has_cpu_groups = True

    @property
    def has_cpu_groups(self) -> bool:
        """has_cpu_groups
        Whether the CPU process groups have been initialized.

        Returns:
            bool: True if the CPU process groups have been initialized.
        """
        return self._has_cpu_groups

    def __repr__(self):
        if self.is_init:
            ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
            personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
            return ranks_str + personal_str
        else:
            return "ProcessGroup not initialized"

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if self._rank != obj._rank:
            return False
        if self._rank_list != obj._rank_list:
            return False
        if self._tp_rank_list != obj._tp_rank_list:
            return False
        if self._dp_rank_list != obj._dp_rank_list:
            return False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

    def rank(self) -> int:
        """rank

        The current rank in the global process group.

        Returns:
            int: the rank number
        """
        return self._rank

    def ranks_in_group(self) -> List[int]:
        """ranks_in_group

        A list of rank numbers in the global process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._rank_list

    def world_size(self) -> int:
        """world_size

        The world size of the global process group.

        Returns:
            int: world size
        """
        return self._world_size

    def tp_rank_list(self) -> List[int]:
        """tp_rank_list

        The rank list of the TP process group containing the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._tp_rank_list

    def dp_rank_list(self) -> List[int]:
        """dp_rank_list

        The rank list of the DP process group containing the current rank.

        Returns:
            List[int]: the list of rank numbers.
        """
        return self._dp_rank_list

    def tp_local_rank(self) -> int:
        """tp_local_rank

        The local rank number in the current TP process group.

        Returns:
            int: tp rank number.
        """
        return self._rank % self._tp_degree

    def dp_local_rank(self) -> int:
        """dp_local_rank

        The local rank number in the current DP process group.

        Returns:
            int: dp rank number.
        """
        return self._rank // self._tp_degree

    def dp_world_size(self) -> int:
        """dp_world_size

        The world size of the current DP process group.

        Returns:
            int: dp world size
        """
        return len(self._dp_rank_list)

    def tp_world_size(self) -> int:
        """tp_world_size

        The world size of the current TP process group.

        Returns:
            int: tp world size
        """
        return len(self._tp_rank_list)

    def dp_process_group(self):
        """dp_process_group

        The PyTorch DP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch DP process group.
        """
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

    def tp_process_group(self):
        """tp_process_group

        The PyTorch TP process group containing the current rank.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch TP process group.
        """
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

    def cpu_dp_process_group(self):
        """cpu_dp_process_group

        The PyTorch CPU DP process group containing the current rank.

        Raises an assertion error if the CPU process groups have not been initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch DP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

    def cpu_tp_process_group(self):
        """cpu_tp_process_group

        The PyTorch CPU TP process group containing the current rank.

        Raises an assertion error if the CPU process groups have not been initialized.

        Returns:
            `torch._C._distributed_c10d.ProcessGroup`: the PyTorch TP process group.
        """
        assert self._has_cpu_groups
        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')

    def get_ranks_in_dp(self) -> List[int]:
        """get_ranks_in_dp

        Ranks in the current DP process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._dp_rank_list

    def get_ranks_in_tp(self):
        """get_ranks_in_tp

        Ranks in the current TP process group.

        Returns:
            List[int]: a list of rank numbers.
        """
        return self._tp_rank_list
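As a usage sketch of the removed ProcessGroup, assuming torch.distributed has already been initialized with 4 ranks (the degrees below are illustrative):

# With 4 ranks, tp_degree=2 and dp_degree=2 arrange the ranks as
#   TP groups: [0, 1] and [2, 3];  DP groups: [0, 2] and [1, 3]
pg = ProcessGroup(tp_degree=2, dp_degree=2)
pg.tp_world_size()       # 2
pg.dp_world_size()       # 2
pg.tp_local_rank()       # rank % 2: this rank's position inside its TP group
pg.tp_process_group()    # the underlying NCCL group, cached in PYTORCHPGDICT_

pg.set_cpu_groups()      # additionally create gloo groups for CPU collectives
cpu_group = pg.cpu_tp_process_group()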
@@ -1,20 +0,0 @@
from dataclasses import dataclass
from typing import Optional

from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup

from .compute_spec import ComputeSpec


@dataclass
class ColoTensorSpec:
    """ColoTensorSpec

    A data class for specifications of the `ColoTensor`.
    It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
    The latter two attributes are optional. If not set, they default to `ReplicaSpec()` and `None`, respectively.
    """
    pg: ProcessGroup
    dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
    compute_attr: Optional[ComputeSpec] = None
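Finally, a hedged sketch of how the removed pieces fit together in a ColoTensorSpec; the process group and partitioning below are assumptions for illustration.

# Assumed: a ProcessGroup with tp_degree=2, as in the earlier sketch.
pg = ProcessGroup(tp_degree=2)
spec = ColoTensorSpec(
    pg=pg,
    dist_attr=ShardSpec(dims=[-1], num_partitions=[2]),    # shard the last dim over the TP group
    compute_attr=ComputeSpec(ComputePattern.TP1D),
)
# Passing only the process group keeps the defaults: a replicated dist spec and no compute spec.
plain_spec = ColoTensorSpec(pg=pg)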