Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -6,12 +6,12 @@ from .process_group import ProcessGroup
 from .tensor_spec import ColoTensorSpec

 __all__ = [
-    'ComputePattern',
-    'ComputeSpec',
-    'distspec',
-    'DistSpecManager',
-    'ProcessGroup',
-    'ColoTensorSpec',
-    'ShardSpec',
-    'ReplicaSpec',
+    "ComputePattern",
+    "ComputeSpec",
+    "distspec",
+    "DistSpecManager",
+    "ProcessGroup",
+    "ColoTensorSpec",
+    "ShardSpec",
+    "ReplicaSpec",
 ]
@@ -23,7 +23,7 @@ class ComputeSpec(object):
         self.output_replicate = True

     def __repr__(self):
-        return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
+        return f"ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})"

     def set_output_replicate(self, flag: bool = True):
         self.output_replicate = flag
@@ -3,4 +3,4 @@ from enum import Enum

 class TensorType(Enum):
     MODEL = 0
-    NONMODEL = 1    # mainly activations
+    NONMODEL = 1  # mainly activations
@@ -20,14 +20,12 @@ def divide(numerator, denominator):
     Returns:
         int: the result of exact division.
     """
-    assert denominator != 0, 'denominator can not be zero'
-    assert numerator % denominator == 0, \
-        '{} is not divisible by {}'.format(numerator, denominator)
+    assert denominator != 0, "denominator can not be zero"
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
    return numerator // denominator


 class TransformDistSpec(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
         ctx.old_dist_spec = old_dist_spec
@@ -38,12 +36,17 @@ class TransformDistSpec(torch.autograd.Function):

     @staticmethod
     def backward(ctx, grad_outputs):
-        return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
-                                       ctx.pg), None, None, None, None, None
+        return (
+            ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg),
+            None,
+            None,
+            None,
+            None,
+            None,
+        )


 class DistSpecManager:
-
     _use_autograd_function: bool = True

     @staticmethod
@@ -51,8 +54,9 @@ class DistSpecManager:
         pass

     @staticmethod
-    def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
-                  pg: ProcessGroup) -> torch.Tensor:
+    def _shard_as(
+        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
+    ) -> torch.Tensor:
         """_shard_as: shard the tensor w.r.t a distributed specification.
         Assuming the tensor passed in is a global (replicated) tensor.
         Args:
@@ -62,7 +66,9 @@ class DistSpecManager:
         Returns:
             torch.Tensor: a torch tensor after sharded.
         """
-        assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
+        assert (
+            old_dist_spec.placement.value == "r"
+        ), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
         DistSpecManager._sanity_check(old_dist_spec, dist_spec)

         chunk = tensor
@@ -86,9 +92,9 @@ class DistSpecManager:
         Returns:
             torch.Tensor: a replicated tensor.
         """
-        assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
+        assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
         is_cpu_tensor = False
-        if tensor.device.type == 'cpu':
+        if tensor.device.type == "cpu":
             # pytorch lower than 1.11 dose not support gather a cpu tensor.
             # Therefore, we transfer tensor to GPU before gather.
             saved_dev = tensor.device
@@ -96,14 +102,14 @@ class DistSpecManager:
             is_cpu_tensor = True

         buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
-        assert tensor.device.type == 'cuda'
+        assert tensor.device.type == "cuda"
         dist.all_gather(buffer, tensor, group=pg.tp_process_group())
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
             new_buffer = []
             dim = old_dist_spec.dims[i]
             num_parts = old_dist_spec.num_partitions[i]
             for start in range(0, len(buffer), num_parts):
-                new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
+                new_buffer.append(torch.cat(buffer[start : start + num_parts], dim))
             buffer = new_buffer
         assert len(buffer) == 1

@@ -112,15 +118,17 @@ class DistSpecManager:
         return buffer[0]

     @staticmethod
-    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
-                    pg: ProcessGroup) -> torch.Tensor:
+    def _all_to_all(
+        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
+    ) -> torch.Tensor:
         world_size = pg.tp_world_size()
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda", \
-            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
+        assert tensor.device.type == "cuda", (
+            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll "
             f"collective function, however, we got {tensor.device.type} device"
+        )

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
@@ -164,8 +172,9 @@ class DistSpecManager:
         return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)

     @staticmethod
-    def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
-                          pg: ProcessGroup) -> torch.Tensor:
+    def handle_trans_spec(
+        tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
+    ) -> torch.Tensor:
         assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
         assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"

@@ -174,7 +183,7 @@ class DistSpecManager:
             (DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
             (DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
             (DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
-            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
+            (DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s,
         }

         forward_trans_handle = trans_funcs[trans_func_key]
@@ -183,8 +192,9 @@ class DistSpecManager:

         backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]

-        return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
-                                       backward_trans_handle)
+        return TransformDistSpec.apply(
+            tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle
+        )

     @staticmethod
     @contextmanager
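Aside (not part of this commit): a hypothetical usage sketch of the API reformatted above, assuming a running 2-rank NCCL job and the ShardSpec/ReplicaSpec helpers exported by colossalai.tensor. For a (REPLICATE, SHARD) pair, handle_trans_spec picks DistSpecManager._r2s from trans_funcs and wraps it in TransformDistSpec so gradients are converted back on the backward pass.

    # Illustrative only; requires torch.distributed to be initialized with 2 CUDA ranks.
    import torch
    from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec

    pg = ProcessGroup(tp_degree=2)        # 2-way tensor parallelism
    replica = ReplicaSpec()               # current spec: fully replicated
    shard = ShardSpec([-1], [2])          # target spec: split the last dim into 2 partitions

    x = torch.randn(4, 8, device="cuda")
    local_chunk = DistSpecManager.handle_trans_spec(x, replica, shard, pg)
    print(local_chunk.shape)              # expected torch.Size([4, 4]) on each TP rank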
@@ -1,12 +1,12 @@
 from enum import Enum
 from typing import List

-__all__ = ['ReplicaSpec', 'ShardSpec']
+__all__ = ["ReplicaSpec", "ShardSpec"]


 class DistPlacementPattern(Enum):
-    REPLICATE = 'r'
-    SHARD = 's'
+    REPLICATE = "r"
+    SHARD = "s"


 class _DistSpec:
@@ -25,7 +25,6 @@ class _DistSpec:
     """

     def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
-
         self.placement = dist_placement_pattern
         for k, v in meta_info.items():
             setattr(self, k, v)
@@ -34,15 +33,15 @@ class _DistSpec:
         if dir(self) != dir(other):
             return False
         for attr in dir(self):
-            if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
+            if not attr.startswith("__") and getattr(self, attr) != getattr(other, attr):
                 return False
         return True

     def __repr__(self) -> str:
         attr_list = []
         for attr in dir(self):
-            if not attr.startswith('__'):
-                attr_list.append(f'{attr}={str(getattr(self, attr))}')
+            if not attr.startswith("__"):
+                attr_list.append(f"{attr}={str(getattr(self, attr))}")
         attr_str = ", ".join(attr_list)
         return "DistSpec(" + attr_str + ")"
@@ -7,13 +7,12 @@ from colossalai.logging import get_dist_logger


 class PyTorchProcessGroupDict(metaclass=SingletonMeta):
-
     def __init__(self):
         # distributed settings
         # use this dict to record all Pytorch ProcessGroups
         self.dict = {}
         # set a distributed logger
-        self.logger = get_dist_logger('ProcessGroup')
+        self.logger = get_dist_logger("ProcessGroup")

     def log_pg_init(self, rank_list: List[int], backend: str):
         str_list = ["Pytorch ProcessGroup Init:"]
@@ -21,9 +20,8 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         str_list.append(f"ranks: {rank_list}")
         self.logger.info("\n\t".join(str_list), ranks=[0])

-    def get(self, rank_list: List[int], backend: str = 'nccl'):
-        """Reuse Pytorch ProcessGroup when such a group is initialized
-        """
+    def get(self, rank_list: List[int], backend: str = "nccl"):
+        """Reuse Pytorch ProcessGroup when such a group is initialized"""
         # we need to convert the passed list to a tuple
         # since List is unhashable
         processgroup_key = (backend, tuple(rank_list))
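Aside (not part of this commit): a minimal standalone sketch of the caching pattern behind get() above. A list of ranks is unhashable, so (backend, tuple(ranks)) becomes the dictionary key; the placeholder object() stands in for torch.distributed.new_group so the snippet runs without a distributed backend.

    _cache = {}

    def get_group(rank_list, backend="nccl"):
        key = (backend, tuple(rank_list))   # tuple(...) makes the rank list hashable
        if key not in _cache:
            _cache[key] = object()          # stand-in for dist.new_group(rank_list, backend=backend)
        return _cache[key]

    assert get_group([0, 1]) is get_group([0, 1])              # second call reuses the cached group
    assert get_group([0, 1], "gloo") is not get_group([0, 1])  # a different backend gets its own group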
@@ -51,11 +49,13 @@ class ProcessGroup:
         dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. . default None means len(ranks).
     """

-    def __init__(self,
-                 rank: Optional[int] = None,
-                 ranks: Optional[List[int]] = None,
-                 tp_degree: Optional[int] = None,
-                 dp_degree: Optional[int] = None) -> None:
+    def __init__(
+        self,
+        rank: Optional[int] = None,
+        ranks: Optional[List[int]] = None,
+        tp_degree: Optional[int] = None,
+        dp_degree: Optional[int] = None,
+    ) -> None:
         if not torch.distributed.is_initialized():
             self.is_init = False
             return
@@ -64,13 +64,13 @@ class ProcessGroup:

         self._rank = torch.distributed.get_rank()
         if rank is not None:
-            assert self._rank == rank    # make sure that the global rank is correct
+            assert self._rank == rank  # make sure that the global rank is correct

         if ranks is None:
             self._rank_list = list(range(torch.distributed.get_world_size()))
         else:
             self._rank_list = ranks
-            self._rank_list.sort()    # ensure that the list is in order
+            self._rank_list.sort()  # ensure that the list is in order

         self._world_size = len(self._rank_list)

@@ -79,31 +79,36 @@ class ProcessGroup:
             self._tp_degree = 1
         elif dp_degree and not tp_degree:
             self._dp_degree = dp_degree
-            assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
+            assert (
+                self._world_size % self._dp_degree == 0
+            ), f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
             self._tp_degree = self._world_size // dp_degree
         elif not dp_degree and tp_degree:
             self._tp_degree = tp_degree
-            assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
+            assert (
+                self._world_size % self._tp_degree == 0
+            ), f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
             self._dp_degree = self._world_size // tp_degree
         else:
             self._dp_degree = dp_degree
             self._tp_degree = tp_degree
-            assert self._dp_degree * self._tp_degree == self._world_size, \
-                f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
+            assert self._dp_degree * self._tp_degree == self._world_size, (
+                f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}"
                 f"and TP degree {self._tp_degree}"
+            )

         self._tp_rank_list = None
         self._dp_rank_list = None

         for i in range(self._dp_degree):
             i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
-            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
+            PYTORCHPGDICT_.get(i_tp_list, "nccl")
             if self._rank in i_tp_list:
                 self._tp_rank_list = i_tp_list

         for j in range(self._tp_degree):
             j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
-            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
+            PYTORCHPGDICT_.get(j_dp_list, "nccl")
             if self._rank in j_dp_list:
                 self._dp_rank_list = j_dp_list

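Aside (not part of this commit): a standalone sketch of the rank grouping performed by the two loops above. Ranks are laid out row-major on a (dp_degree x tp_degree) grid, so each row becomes a tensor-parallel group and each column a data-parallel group; the numbers below assume a hypothetical world size of 8.

    rank_list = list(range(8))    # assumed world size 8
    dp_degree, tp_degree = 2, 4

    tp_groups = [[rank_list[i * tp_degree + j] for j in range(tp_degree)] for i in range(dp_degree)]
    dp_groups = [[rank_list[i * tp_degree + j] for i in range(dp_degree)] for j in range(tp_degree)]

    print(tp_groups)    # [[0, 1, 2, 3], [4, 5, 6, 7]]     -> one NCCL group per row
    print(dp_groups)    # [[0, 4], [1, 5], [2, 6], [3, 7]] -> one NCCL group per column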
@@ -119,11 +124,11 @@ class ProcessGroup:

         for i in range(self._dp_degree):
             i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
-            PYTORCHPGDICT_.get(i_tp_list, 'gloo')
+            PYTORCHPGDICT_.get(i_tp_list, "gloo")

         for j in range(self._tp_degree):
             j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
-            PYTORCHPGDICT_.get(j_dp_list, 'gloo')
+            PYTORCHPGDICT_.get(j_dp_list, "gloo")

         self._has_cpu_groups = True

@@ -145,7 +150,7 @@ class ProcessGroup:
         else:
             return "ProcessGroup not initialized"

-    def __eq__(self, obj: 'ProcessGroup') -> bool:
+    def __eq__(self, obj: "ProcessGroup") -> bool:
         if not isinstance(obj, ProcessGroup):
             return False
         if self._rank != obj._rank:
@@ -260,7 +265,7 @@ class ProcessGroup:
         Returns:
             `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
         """
-        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
+        return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl")

     def tp_process_group(self):
         """tp_process_group
@@ -270,7 +275,7 @@ class ProcessGroup:
         Returns:
             `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
         """
-        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+        return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl")

     def cpu_dp_process_group(self):
         """cpu_dp_process_group
@@ -283,7 +288,7 @@ class ProcessGroup:
             `torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
         """
         assert self._has_cpu_groups
-        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo")

     def cpu_tp_process_group(self):
         """cpu_tp_process_group
@@ -296,7 +301,7 @@ class ProcessGroup:
             `torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
         """
         assert self._has_cpu_groups
-        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+        return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo")

     def get_ranks_in_dp(self) -> List[int]:
         """get_ranks_in_dp
@@ -9,12 +9,13 @@ from .compute_spec import ComputeSpec

 @dataclass
 class ColoTensorSpec:
-    """ ColoTensorSpec
+    """ColoTensorSpec

     A data class for specifications of the `ColoTensor`.
     It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
     The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
     """
+
     pg: ProcessGroup
     dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
     compute_attr: Optional[ComputeSpec] = None
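Aside (not part of this commit): a hypothetical construction of a ColoTensorSpec matching the defaults above. ComputePattern.TP1D is assumed here as an example pattern, and ProcessGroup requires an initialized torch.distributed environment.

    from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec

    pg = ProcessGroup(tp_degree=2)                  # assumes torch.distributed is initialized
    spec = ColoTensorSpec(
        pg,                                         # required process group
        dist_attr=ShardSpec([0], [2]),              # overrides the default _DistSpec(REPLICATE)
        compute_attr=ComputeSpec(ComputePattern.TP1D),  # optional; defaults to None
    )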