[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions


@@ -1,18 +1,11 @@
-from . import distspec
 from .colo_parameter import ColoParameter
 from .colo_tensor import ColoTensor
 from .comm_spec import CollectiveCommPattern, CommSpec
-from .compute_spec import ComputePattern, ComputeSpec
-from .dist_spec_mgr import DistSpecManager
-from .distspec import ReplicaSpec, ShardSpec
 from .param_op_hook import ColoParamOpHook, ColoParamOpHookManager
-from .process_group import ProcessGroup
-from .tensor_spec import ColoTensorSpec
 from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
 
 __all__ = [
-    'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
-    'distspec', 'DistSpecManager', 'ColoParamOpHook', 'ColoParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec',
-    'ShardSpec', 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
+    'ColoTensor', 'convert_parameter', 'named_params_with_colotensor', 'ColoParameter', 'ColoParamOpHook',
+    'ColoParamOpHookManager', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict',
     'merge_same_dim_mesh_list'
 ]
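
After this hunk, only the tensor/parameter primitives and the comm/utility helpers stay public; the distributed-spec machinery is removed. A hypothetical smoke test of the trimmed surface (names taken from the new `__all__`; not part of this commit):

from colossalai.tensor import ColoParameter, ColoTensor, CommSpec

try:
    from colossalai.tensor import ProcessGroup  # removed by this cleanup
except ImportError:
    print("ProcessGroup is no longer exported from colossalai.tensor")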


@@ -1,29 +0,0 @@
from enum import Enum
class ComputePattern(Enum):
TP1D = 0
TP2D = 1
TP2P5D = 2
TP3D = 3
class ComputeSpec(object):
"""ComputeSpec
The specification of a computation pattern.
Args:
compute_pattern (ComputePattern): an Enum instance for compute pattern.
"""
def __init__(self, compute_pattern: ComputePattern) -> None:
assert isinstance(compute_pattern, ComputePattern)
self.compute_pattern = compute_pattern
# Make sure output tensors are replicate
self.output_replicate = True
def __repr__(self):
return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag
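
For reference, the removed spec was consumed roughly like this (a minimal sketch using only the class definitions above):

spec = ComputeSpec(ComputePattern.TP1D)
print(spec)                        # ComputeSpec(pattern=ComputePattern.TP1D, replicate_output=True)
spec.set_output_replicate(False)   # keep outputs sharded instead of replicating them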


@@ -1,6 +0,0 @@
from enum import Enum
class TensorType(Enum):
MODEL = 0
NONMODEL = 1 # mainly activations


@@ -1,196 +0,0 @@
from contextlib import contextmanager
import torch
import torch.distributed as dist
from numpy import prod
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
# TODO(jiaruifang) circular import: move `divide` to colossalai.commons.
# colossalai.tensor shall not import any submodule from colossalai.nn
def divide(numerator, denominator):
"""Only allow exact division.
Args:
numerator (int): Numerator of the division.
denominator (int): Denominator of the division.
Returns:
int: the result of exact division.
"""
assert denominator != 0, 'denominator can not be zero'
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
class TransformDistSpec(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
ctx.old_dist_spec = old_dist_spec
ctx.dist_spec = dist_spec
ctx.backward_trans_func = backward_trans_func
ctx.pg = pg
return forward_trans_func(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def backward(ctx, grad_outputs):
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
ctx.pg), None, None, None, None, None
class DistSpecManager:
_use_autograd_function: bool = True
@staticmethod
def _sanity_check(old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> None:
pass
@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
tensor (torch.Tensor): a global (replicated) tensor before shard
dist_spec (_DistSpec): the target distributed spec to shard to.
pg (ProcessGroup): the process group of the corresponding colotensor.
Returns:
torch.Tensor: the sharded local tensor.
"""
assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
chunk = tensor
idx = pg.tp_local_rank()
num_parts = prod(dist_spec.num_partitions)
for i, dim in enumerate(dist_spec.dims):
num_parts //= dist_spec.num_partitions[i]
chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
idx %= num_parts
return chunk.clone().detach().contiguous()
@staticmethod
def _gather(tensor: torch.Tensor, old_dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
"""_gather gather sharded tensors to a replicated one.
Args:
tensor (torch.Tensor): a shared torch tensor
old_dist_spec (_DistSpec): the distributed spec. of the tensor.
Returns:
torch.Tensor: a replicated tensor.
"""
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
is_cpu_tensor = False
if tensor.device.type == 'cpu':
# PyTorch versions below 1.11 do not support gathering a CPU tensor.
# Therefore, we move the tensor to GPU before the gather.
saved_dev = tensor.device
tensor.data = tensor.data.cuda()
is_cpu_tensor = True
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
dist.all_gather(buffer, tensor, group=pg.tp_process_group())
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
new_buffer = []
dim = old_dist_spec.dims[i]
num_parts = old_dist_spec.num_partitions[i]
for start in range(0, len(buffer), num_parts):
new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
buffer = new_buffer
assert len(buffer) == 1
if is_cpu_tensor:
buffer[0].data = buffer[0].data.to(saved_dev)
return buffer[0]
@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
world_size = pg.tp_world_size()
if world_size == 1:
return tensor
assert tensor.device.type == "cuda", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
f"collective function, however, we got {tensor.device.type} device"
gather_dim = old_dist_spec.dims[0]
scatter_dim = dist_spec.dims[0]
shapes = list(tensor.shape)
scattered_dim_size = shapes[scatter_dim] // world_size
gathered_dim_size = shapes[gather_dim] * world_size
shapes[scatter_dim] = scattered_dim_size
scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
return output_
@staticmethod
def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return tensor
@staticmethod
def _r2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
return DistSpecManager._gather(tensor, old_dist_spec, pg)
@staticmethod
def _s2s(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup) -> torch.Tensor:
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
if old_dist_spec == dist_spec:
return tensor
if len(old_dist_spec.dims) == 1 and len(dist_spec.dims) == 1:
# use all-to-all to save memory
return DistSpecManager._all_to_all(tensor, old_dist_spec, dist_spec, pg)
tensor = DistSpecManager._gather(tensor, old_dist_spec, pg)
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
trans_func_key = (old_dist_spec.placement, dist_spec.placement)
trans_funcs = {
(DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
(DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
(DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
(DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
}
forward_trans_handle = trans_funcs[trans_func_key]
if not DistSpecManager._use_autograd_function:
return forward_trans_handle(tensor, old_dist_spec, dist_spec, pg)
backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle)
@staticmethod
@contextmanager
def no_grad():
try:
DistSpecManager._use_autograd_function = False
yield
finally:
DistSpecManager._use_autograd_function = True
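
For context, the removed manager was typically driven through `handle_trans_spec`; below is a hedged sketch of pre-cleanup usage, assuming torch.distributed is initialized with an NCCL backend across 4 ranks and that the sibling distspec/process_group modules are still importable:

import torch

from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import ReplicaSpec, ShardSpec
from colossalai.tensor.process_group import ProcessGroup

pg = ProcessGroup(tp_degree=4)            # 4-way tensor parallelism
full = torch.randn(8, 16, device='cuda')  # replicated on every rank

# replicate -> shard (_r2s): each rank keeps a (2, 16) slice of dim 0,
# selected via narrow() according to its tp_local_rank
local = DistSpecManager.handle_trans_spec(full, ReplicaSpec(), ShardSpec([0], [4]), pg)

# shard -> replicate (_s2r): an all-gather rebuilds the full (8, 16) tensor
restored = DistSpecManager.handle_trans_spec(local, ShardSpec([0], [4]), ReplicaSpec(), pg)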


@@ -1,78 +0,0 @@
from enum import Enum
from typing import List
__all__ = ['ReplicaSpec', 'ShardSpec']
class DistPlacementPattern(Enum):
REPLICATE = 'r'
SHARD = 's'
class _DistSpec:
"""_DistSpec
A class that describes a distributed specification.
The DistSpec only applies to tensor parallel process groups,
because the dist spec of a data parallel process group can be deduced automatically.
This is an internal data structure.
The API for users should be `ShardSpec` and `ReplicaSpec`.
Args:
dist_placement_pattern (DistPlacementPattern): the pattern describing how tensors are distributed among processes.
The dist_placement_pattern is picked from a limited set, now including two patterns: replicate and shard.
process_group (Optional[ProcessGroup], optional): the process group containing the processes. Defaults to None.
"""
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
self.placement = dist_placement_pattern
for k, v in meta_info.items():
setattr(self, k, v)
def __eq__(self, other: "_DistSpec") -> bool:
if dir(self) != dir(other):
return False
for attr in dir(self):
if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
return False
return True
def __repr__(self) -> str:
attr_list = []
for attr in dir(self):
if not attr.startswith('__'):
attr_list.append(f'{attr}={str(getattr(self, attr))}')
attr_str = ", ".join(attr_list)
return "DistSpec(" + attr_str + ")"
def ReplicaSpec() -> _DistSpec:
"""ReplicaSpec
A distributed specification indicating that the tensor is replicated across the tensor parallel process group.
Returns:
_DistSpec: a replicated dist spec instance.
"""
return _DistSpec(DistPlacementPattern.REPLICATE)
def ShardSpec(dims: List[int], num_partitions: List[int]) -> _DistSpec:
"""ShardSpec
A distributed specification indicating that the tensor is sharded across the tensor parallel process group.
Note:
Currently, sharding along only one dimension is valid. In other words, `dims` should have length 1.
Args:
dims (List[int]): the dimensions to shard along.
num_partitions (List[int]): the number of partitions for each dimension in `dims`.
Returns:
_DistSpec: a shard dist spec instance.
"""
assert isinstance(dims, list) and isinstance(num_partitions, list)
assert len(dims) == len(num_partitions)
return _DistSpec(DistPlacementPattern.SHARD, dims=tuple(dims), num_partitions=tuple(num_partitions))
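
A concrete reading of the two factory functions (a minimal sketch using only the definitions in this file):

s = ShardSpec([0], [4])           # shard dim 0 into 4 partitions
r = ReplicaSpec()

print(s)                          # DistSpec(dims=(0,), num_partitions=(4,), placement=DistPlacementPattern.SHARD)
print(s == ShardSpec([0], [4]))   # True: __eq__ compares all non-dunder attributes
print(s == r)                     # False: the attribute sets differ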


@@ -1,53 +0,0 @@
from typing import (
Callable,
Dict,
)
import functools
# Custom sharded ops
_COLOSSAL_OPS: Dict[str, Callable] = {}
def _register_colo_op(op, func):
global _COLOSSAL_OPS
_COLOSSAL_OPS[op] = func
def colo_op_impl(func):
"""
Provides a way for users to write their own custom operator. This
can be used to override existing ColoTensor operators or write a new
one not supported by ColoTensor. If the operator in question is covered
by ``__torch_function__`` dispatch and has a ColoTensor as any of its
parameters, the function provided will be invoked for that operator.
Example:
>>> @colo_op_impl(torch.nn.functional.linear)
>>> def my_custom_linear(types, args, kwargs, process_group):
>>> ....
>>>
>>> input = torch.rand(10, 32)
>>> weight = ColoTensor(torch.rand(32, 16))
>>> bias = ColoTensor(torch.rand(16))
>>> # This will call `my_custom_linear` instead of the default.
>>> torch.nn.functional.linear(input, weight, bias)
The types, args and kwargs parameters are the same parameters that are
passed to ``__torch_function__`` dispatch API
(https://pytorch.org/docs/stable/notes/extending.html#extending-torch).
Args:
func(Callable): Torch function for which we want to provide a sharded
implementation (ex: torch.nn.functional.linear)
"""
def decorator_sharded_func(wrapped_func):
_register_colo_op(func, wrapped_func)
@functools.wraps(wrapped_func)
def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)
return wrapper
return decorator_sharded_func
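
The doctest-style example above is schematic; a runnable registration would look roughly like this (a sketch; `my_linear` is a hypothetical handler following the `types, args, kwargs, process_group` convention described above):

import torch

@colo_op_impl(torch.nn.functional.linear)
def my_linear(types, args=(), kwargs=None, process_group=None):
    # Delegate to the stock implementation; a real handler would
    # shard the matmul according to the process group.
    input_, weight, bias = args
    return torch.nn.functional.linear(input_, weight, bias)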


@@ -1,319 +0,0 @@
from typing import List, Optional
import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.logging import get_dist_logger
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
def __init__(self):
# distributed settings
# use this dict to record all Pytorch ProcessGroups
self.dict = {}
# set a distributed logger
self.logger = get_dist_logger('ProcessGroup')
def log_pg_init(self, rank_list: List[int], backend: str):
str_list = ["Pytorch ProcessGroup Init:"]
str_list.append(f"backend: {backend}")
str_list.append(f"ranks: {rank_list}")
self.logger.info("\n\t".join(str_list), ranks=[0])
def get(self, rank_list: List[int], backend: str = 'nccl'):
"""Reuse Pytorch ProcessGroup when such a group is initialized
"""
# we need to convert the passed list to a tuple
# since List is unhashable
processgroup_key = (backend, tuple(rank_list))
if processgroup_key not in self.dict:
self.log_pg_init(rank_list=rank_list, backend=backend)
self.dict[processgroup_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
return self.dict[processgroup_key]
PYTORCHPGDICT_ = PyTorchProcessGroupDict()
class ProcessGroup:
"""ProcessGroup
Process Group indicates how processes are organized in groups for parallel execution using Tensor Parallelism and Data Parallelism.
NOTE: a ProcessGroup must be created after `torch.distributed.init_process_group()` has been called.
Args:
rank: the global rank of the current process.
ranks: List[int], a list of rank ids belonging to this process group.
backend: str, the backend of the process group.
tp_degree: Optional[int], tensor parallelism degree, i.e. how many processes are inside a tp process group. Defaults to None, which means 1.
dp_degree: Optional[int], data parallelism degree, i.e. how many processes are inside a dp process group. Defaults to None, which means len(ranks) // tp_degree.
"""
def __init__(self,
rank: Optional[int] = None,
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
if not torch.distributed.is_initialized():
self.is_init = False
return
assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
self._rank = torch.distributed.get_rank()
if rank is not None:
assert self._rank == rank # make sure that the global rank is correct
if ranks is None:
self._rank_list = list(range(torch.distributed.get_world_size()))
else:
self._rank_list = ranks
self._rank_list.sort() # ensure that the list is in order
self._world_size = len(self._rank_list)
if dp_degree is None and tp_degree is None:
self._dp_degree = self._world_size
self._tp_degree = 1
elif dp_degree and not tp_degree:
self._dp_degree = dp_degree
assert self._world_size % self._dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
self._tp_degree = self._world_size // dp_degree
elif not dp_degree and tp_degree:
self._tp_degree = tp_degree
assert self._world_size % self._tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
self._dp_degree = self._world_size // tp_degree
else:
self._dp_degree = dp_degree
self._tp_degree = tp_degree
assert self._dp_degree * self._tp_degree == self._world_size, \
f"the world size {self._world_size} should equal the product of DP degree {self._dp_degree} " \
f"and TP degree {self._tp_degree}"
self._tp_rank_list = None
self._dp_rank_list = None
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'nccl')
if self._rank in i_tp_list:
self._tp_rank_list = i_tp_list
for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'nccl')
if self._rank in j_dp_list:
self._dp_rank_list = j_dp_list
self._has_cpu_groups = False
self.is_init = True
def set_cpu_groups(self):
"""set_cpu_groups
Initialize Pytorch process groups for cpu communications.
"""
if self.has_cpu_groups:
return
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'gloo')
for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'gloo')
self._has_cpu_groups = True
@property
def has_cpu_groups(self) -> bool:
"""has_cpu_groups
Whether the cpu process groups have been initialized.
Returns:
bool: True if the cpu process groups have been initialized.
"""
return self._has_cpu_groups
def __repr__(self):
if self.is_init:
ranks_str = f"ProcessGroup(ranks={self._rank_list},\n"
personal_str = f" rank={self._rank}, dp={self._dp_degree}, tp={self._tp_degree})"
return ranks_str + personal_str
else:
return "ProcessGroup not initialized"
def __eq__(self, obj: 'ProcessGroup') -> bool:
if not isinstance(obj, ProcessGroup):
return False
if self._rank != obj._rank:
return False
if self._rank_list != obj._rank_list:
return False
if self._tp_rank_list != obj._tp_rank_list:
return False
if self._dp_rank_list != obj._dp_rank_list:
return False
if self._tp_degree != obj._tp_degree:
return False
if self._dp_degree != obj._dp_degree:
return False
return True
def rank(self) -> int:
"""rank
The current rank in the global process group.
Returns:
int: the rank number
"""
return self._rank
def ranks_in_group(self) -> List[int]:
"""ranks_in_group
the list of rank numbers in the global process group.
Returns:
List[int]: a list of rank number.
"""
return self._rank_list
def world_size(self) -> int:
"""world_size
The world size of the global process group.
Returns:
int: world size
"""
return self._world_size
def tp_rank_list(self) -> List[int]:
"""tp_rank_list
the rank list in the TP process group containing the current rank.
Returns:
List[int]: the list of rank number.
"""
return self._tp_rank_list
def dp_rank_list(self) -> List[int]:
"""dp_rank_list
the rank list in the DP process group containing the current rank.
Returns:
List[int]: the list of rank number.
"""
return self._dp_rank_list
def tp_local_rank(self) -> int:
"""tp_local_rank
The local rank number in the current TP process group.
Returns:
int: tp rank number.
"""
return self._rank % self._tp_degree
def dp_local_rank(self) -> int:
"""dp_local_rank
The local rank number in the current DP process group.
Returns:
int: dp rank number.
"""
return self._rank // self._tp_degree
def dp_world_size(self) -> int:
"""dp_world_size
The world size of the current DP process group.
Returns:
int: dp world size
"""
return len(self._dp_rank_list)
def tp_world_size(self) -> int:
"""tp_world_size
The world size of the current TP process group.
Returns:
int: tp world size
"""
return len(self._tp_rank_list)
def dp_process_group(self):
"""dp_process_group
the pytorch DP process group containing the current rank.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def tp_process_group(self):
"""tp_process_group
the pytorch TP process group containing the current rank.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
def cpu_dp_process_group(self):
"""cpu_dp_process_group
the pytorch CPU DP process group containing the current rank.
Raises an assertion error if the cpu process groups have not been initialized.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
def cpu_tp_process_group(self):
"""cpu_tp_process_group
the pytorch CPU TP process group containing the current rank.
Raises an assertion error if the cpu process groups have not been initialized.
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
def get_ranks_in_dp(self) -> List[int]:
"""get_ranks_in_dp
ranks in current dp process group.
Returns:
List[int]: a list of rank number.
"""
return self._dp_rank_list
def get_ranks_in_tp(self):
"""get_ranks_in_tp
ranks in current tp process group.
Returns:
List[int]: a list of rank number.
"""
return self._tp_rank_list
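
To make the rank layout concrete, here is a hedged sketch assuming torch.distributed is already initialized across 8 processes:

pg = ProcessGroup(tp_degree=2)  # world size 8 -> dp_degree deduced as 4
# ranks [0..7] with tp_degree=2 yield TP groups [0,1], [2,3], [4,5], [6,7]
# and DP groups [0,2,4,6], [1,3,5,7]
assert pg.tp_world_size() == 2 and pg.dp_world_size() == 4
# on rank 3: tp_local_rank() == 3 % 2 == 1 and dp_local_rank() == 3 // 2 == 1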


@@ -1,20 +0,0 @@
from dataclasses import dataclass
from typing import Optional
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from .compute_spec import ComputeSpec
@dataclass
class ColoTensorSpec:
""" ColoTensorSpec
A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional; if not set, they default to a replicated `_DistSpec` and `None`, respectively.
"""
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
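
A minimal construction sketch (assumes `ShardSpec`, `ComputePattern` and `ComputeSpec` are imported from the sibling modules shown above):

pg = ProcessGroup()              # all ranks; dp degree defaults to the world size
replicated = ColoTensorSpec(pg)  # dist_attr defaults to a replicated _DistSpec
sharded = ColoTensorSpec(pg, dist_attr=ShardSpec([0], [2]), compute_attr=ComputeSpec(ComputePattern.TP1D))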