[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00 (committed by GitHub)
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -6,12 +6,12 @@ from .process_group import ProcessGroup
from .tensor_spec import ColoTensorSpec
__all__ = [
'ComputePattern',
'ComputeSpec',
'distspec',
'DistSpecManager',
'ProcessGroup',
'ColoTensorSpec',
'ShardSpec',
'ReplicaSpec',
"ComputePattern",
"ComputeSpec",
"distspec",
"DistSpecManager",
"ProcessGroup",
"ColoTensorSpec",
"ShardSpec",
"ReplicaSpec",
]

View File

@@ -23,7 +23,7 @@ class ComputeSpec(object):
self.output_replicate = True
def __repr__(self):
return f'ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})'
return f"ComputeSpec(pattern={self.compute_pattern}, replicate_output={self.output_replicate})"
def set_output_replicate(self, flag: bool = True):
self.output_replicate = flag
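
A minimal usage sketch for the ComputeSpec shown above. Assumptions (not visible in this hunk): the constructor takes a ComputePattern, and ComputePattern has a TP1D member; colossalai must be installed.

from colossalai.tensor import ComputePattern, ComputeSpec

spec = ComputeSpec(ComputePattern.TP1D)   # declare a 1D tensor-parallel compute pattern (assumed member)
print(spec)                               # ComputeSpec(pattern=..., replicate_output=True)
spec.set_output_replicate(False)          # keep outputs sharded instead of gathering them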

View File

@@ -3,4 +3,4 @@ from enum import Enum
class TensorType(Enum):
MODEL = 0
NONMODEL = 1 # mainly activations
NONMODEL = 1 # mainly activations

View File

@@ -20,14 +20,12 @@ def divide(numerator, denominator):
Returns:
int: the result of exact division.
"""
assert denominator != 0, 'denominator can not be zero'
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
assert denominator != 0, "denominator can not be zero"
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
return numerator // denominator
class TransformDistSpec(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, old_dist_spec, dist_spec, pg, forward_trans_func, backward_trans_func):
ctx.old_dist_spec = old_dist_spec
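
The divide() helper reformatted above is small enough to restate standalone; this sketch just shows its contract of exact integer division, which the sharding code relies on when splitting a dimension across ranks.

def divide(numerator: int, denominator: int) -> int:
    # exact division that fails loudly on remainders
    assert denominator != 0, "denominator can not be zero"
    assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)
    return numerator // denominator

print(divide(1024, 8))  # 128 rows per rank when sharding dim 0 over 8 ranks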
@@ -38,12 +36,17 @@ class TransformDistSpec(torch.autograd.Function):
@staticmethod
def backward(ctx, grad_outputs):
return ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec,
ctx.pg), None, None, None, None, None
return (
ctx.backward_trans_func(grad_outputs, ctx.dist_spec, ctx.old_dist_spec, ctx.pg),
None,
None,
None,
None,
None,
)
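
A toy restatement of the torch.autograd.Function contract that TransformDistSpec follows above: backward() must return one gradient per argument of forward(), so every non-tensor argument (the specs, the process group, the two transform functions) gets a None.

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, factor):
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # one gradient for `tensor`, None for the non-tensor `factor`
        return grad_output * ctx.factor, None

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])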
class DistSpecManager:
_use_autograd_function: bool = True
@staticmethod
@@ -51,8 +54,9 @@ class DistSpecManager:
pass
@staticmethod
def _shard_as(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
def _shard_as(
tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
) -> torch.Tensor:
"""_shard_as: shard the tensor w.r.t a distributed specification.
Assuming the tensor passed in is a global (replicated) tensor.
Args:
@@ -62,7 +66,9 @@ class DistSpecManager:
Returns:
torch.Tensor: the torch tensor after sharding.
"""
assert old_dist_spec.placement.value == 'r', f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
assert (
old_dist_spec.placement.value == "r"
), f"The old_dist_spec of DistSpecManager._shard_as must be REPLICATE!"
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
chunk = tensor
@@ -86,9 +92,9 @@ class DistSpecManager:
Returns:
torch.Tensor: a replicated tensor.
"""
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
assert old_dist_spec.placement.value == "s", f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
is_cpu_tensor = False
if tensor.device.type == 'cpu':
if tensor.device.type == "cpu":
# PyTorch versions lower than 1.11 do not support gathering a CPU tensor.
# Therefore, we transfer the tensor to GPU before gathering.
saved_dev = tensor.device
@@ -96,14 +102,14 @@ class DistSpecManager:
is_cpu_tensor = True
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda'
assert tensor.device.type == "cuda"
dist.all_gather(buffer, tensor, group=pg.tp_process_group())
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
new_buffer = []
dim = old_dist_spec.dims[i]
num_parts = old_dist_spec.num_partitions[i]
for start in range(0, len(buffer), num_parts):
new_buffer.append(torch.cat(buffer[start:start + num_parts], dim))
new_buffer.append(torch.cat(buffer[start : start + num_parts], dim))
buffer = new_buffer
assert len(buffer) == 1
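
A single-process sketch of the concatenation step in _gather() above: after the all_gather, shards from each tensor-parallel rank are stitched back along the sharded dimension(s). Here the all_gather result is faked by chunking a local tensor, so no process group is needed.

import torch

full = torch.arange(16).reshape(4, 4)
dims, num_partitions = [0], [4]  # e.g. a ShardSpec that splits dim 0 into 4 parts
buffer = list(torch.chunk(full, num_partitions[0], dim=dims[0]))  # what all_gather would return

for i in range(len(dims) - 1, -1, -1):
    dim, parts = dims[i], num_partitions[i]
    buffer = [torch.cat(buffer[s : s + parts], dim) for s in range(0, len(buffer), parts)]

assert len(buffer) == 1 and torch.equal(buffer[0], full)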
@@ -112,15 +118,17 @@ class DistSpecManager:
return buffer[0]
@staticmethod
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
def _all_to_all(
tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
) -> torch.Tensor:
world_size = pg.tp_world_size()
if world_size == 1:
return tensor
assert tensor.device.type == "cuda", \
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
assert tensor.device.type == "cuda", (
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll "
f"collective function, however, we got {tensor.device.type} device"
)
gather_dim = old_dist_spec.dims[0]
scatter_dim = dist_spec.dims[0]
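
A local, single-process sketch of what the all-to-all re-shard above achieves: switching the sharded dimension (here dim 0 to dim 1) without materialising the full tensor on any one rank. The real path uses dist.all_to_all over NCCL, which is why the CUDA assert precedes it.

import torch

world_size = 4
full = torch.arange(64, dtype=torch.float32).reshape(8, 8)
row_shards = list(torch.chunk(full, world_size, dim=0))  # old spec: sharded along dim 0
col_shards = list(torch.chunk(full, world_size, dim=1))  # new spec: sharded along dim 1

for rank in range(world_size):
    # each rank receives its column block from every other rank's row shard ...
    recv = [torch.chunk(row_shards[src], world_size, dim=1)[rank] for src in range(world_size)]
    # ... and concatenating those blocks along dim 0 yields exactly its new column shard
    assert torch.equal(torch.cat(recv, dim=0), col_shards[rank])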
@@ -164,8 +172,9 @@ class DistSpecManager:
return DistSpecManager._shard_as(tensor, old_dist_spec, dist_spec, pg)
@staticmethod
def handle_trans_spec(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec,
pg: ProcessGroup) -> torch.Tensor:
def handle_trans_spec(
tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec, pg: ProcessGroup
) -> torch.Tensor:
assert isinstance(old_dist_spec, _DistSpec), f"{type(old_dist_spec)} should be _DistSpec"
assert isinstance(dist_spec, _DistSpec), f"{type(dist_spec)} should be _DistSpec"
@@ -174,7 +183,7 @@ class DistSpecManager:
(DistPlacementPattern.REPLICATE, DistPlacementPattern.REPLICATE): DistSpecManager._r2r,
(DistPlacementPattern.REPLICATE, DistPlacementPattern.SHARD): DistSpecManager._r2s,
(DistPlacementPattern.SHARD, DistPlacementPattern.REPLICATE): DistSpecManager._s2r,
(DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s
(DistPlacementPattern.SHARD, DistPlacementPattern.SHARD): DistSpecManager._s2s,
}
forward_trans_handle = trans_funcs[trans_func_key]
@@ -183,8 +192,9 @@ class DistSpecManager:
backward_trans_handle = trans_funcs[(dist_spec.placement, old_dist_spec.placement)]
return TransformDistSpec.apply(tensor, old_dist_spec, dist_spec, pg, forward_trans_handle,
backward_trans_handle)
return TransformDistSpec.apply(
tensor, old_dist_spec, dist_spec, pg, forward_trans_handle, backward_trans_handle
)
@staticmethod
@contextmanager
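
A toy restatement of the dispatch pattern in handle_trans_spec above: the (source placement, target placement) pair selects the forward transform, and the reversed pair selects the backward transform handed to TransformDistSpec.apply.

from enum import Enum

class P(Enum):
    REPLICATE = "r"
    SHARD = "s"

funcs = {
    (P.REPLICATE, P.REPLICATE): lambda t: t,           # stands in for _r2r
    (P.REPLICATE, P.SHARD): lambda t: f"shard({t})",   # stands in for _r2s
    (P.SHARD, P.REPLICATE): lambda t: f"gather({t})",  # stands in for _s2r
    (P.SHARD, P.SHARD): lambda t: f"reshard({t})",     # stands in for _s2s
}
forward_handle = funcs[(P.REPLICATE, P.SHARD)]
backward_handle = funcs[(P.SHARD, P.REPLICATE)]  # key reversed, as in the code above
print(forward_handle("x"), backward_handle("x"))  # shard(x) gather(x)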

View File

@@ -1,12 +1,12 @@
from enum import Enum
from typing import List
__all__ = ['ReplicaSpec', 'ShardSpec']
__all__ = ["ReplicaSpec", "ShardSpec"]
class DistPlacementPattern(Enum):
REPLICATE = 'r'
SHARD = 's'
REPLICATE = "r"
SHARD = "s"
class _DistSpec:
@@ -25,7 +25,6 @@ class _DistSpec:
"""
def __init__(self, dist_placement_pattern: DistPlacementPattern, **meta_info):
self.placement = dist_placement_pattern
for k, v in meta_info.items():
setattr(self, k, v)
@@ -34,15 +33,15 @@ class _DistSpec:
if dir(self) != dir(other):
return False
for attr in dir(self):
if not attr.startswith('__') and getattr(self, attr) != getattr(other, attr):
if not attr.startswith("__") and getattr(self, attr) != getattr(other, attr):
return False
return True
def __repr__(self) -> str:
attr_list = []
for attr in dir(self):
if not attr.startswith('__'):
attr_list.append(f'{attr}={str(getattr(self, attr))}')
if not attr.startswith("__"):
attr_list.append(f"{attr}={str(getattr(self, attr))}")
attr_str = ", ".join(attr_list)
return "DistSpec(" + attr_str + ")"

View File

@@ -7,13 +7,12 @@ from colossalai.logging import get_dist_logger
class PyTorchProcessGroupDict(metaclass=SingletonMeta):
def __init__(self):
# distributed settings
# use this dict to record all Pytorch ProcessGroups
self.dict = {}
# set a distributed logger
self.logger = get_dist_logger('ProcessGroup')
self.logger = get_dist_logger("ProcessGroup")
def log_pg_init(self, rank_list: List[int], backend: str):
str_list = ["Pytorch ProcessGroup Init:"]
@@ -21,9 +20,8 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
str_list.append(f"ranks: {rank_list}")
self.logger.info("\n\t".join(str_list), ranks=[0])
def get(self, rank_list: List[int], backend: str = 'nccl'):
"""Reuse Pytorch ProcessGroup when such a group is initialized
"""
def get(self, rank_list: List[int], backend: str = "nccl"):
"""Reuse Pytorch ProcessGroup when such a group is initialized"""
# we need to convert the passed list to a tuple
# since List is unhashable
processgroup_key = (backend, tuple(rank_list))
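
A standalone sketch of the caching idea in get() above: process groups are keyed by (backend, tuple(ranks)) so that repeated requests reuse the same group instead of creating a new one; the tuple conversion is needed because lists are unhashable. The stand-in object() replaces the real dist.new_group call.

_cache = {}

def get_group(rank_list, backend="nccl"):
    key = (backend, tuple(rank_list))
    if key not in _cache:
        _cache[key] = object()  # stand-in for dist.new_group(ranks=rank_list, backend=backend)
    return _cache[key]

assert get_group([0, 1]) is get_group([0, 1])  # the second call reuses the cached group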
@@ -51,11 +49,13 @@ class ProcessGroup:
dp_degree: Optional[int], data parallelism degree. How many processes are inside a dp process group. Default None means len(ranks).
"""
def __init__(self,
rank: Optional[int] = None,
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None) -> None:
def __init__(
self,
rank: Optional[int] = None,
ranks: Optional[List[int]] = None,
tp_degree: Optional[int] = None,
dp_degree: Optional[int] = None,
) -> None:
if not torch.distributed.is_initialized():
self.is_init = False
return
@@ -64,13 +64,13 @@ class ProcessGroup:
self._rank = torch.distributed.get_rank()
if rank is not None:
assert self._rank == rank # make sure that the global rank is correct
assert self._rank == rank # make sure that the global rank is correct
if ranks is None:
self._rank_list = list(range(torch.distributed.get_world_size()))
else:
self._rank_list = ranks
self._rank_list.sort() # ensure that the list is in order
self._rank_list.sort() # ensure that the list is in order
self._world_size = len(self._rank_list)
@@ -79,31 +79,36 @@ class ProcessGroup:
self._tp_degree = 1
elif dp_degree and not tp_degree:
self._dp_degree = dp_degree
assert self._world_size % self._dp_degree == 0, f"the world size {self._world_size} should be divisible by the DP degree {dp_degree} when the TP degree is None"
assert (
self._world_size % self._dp_degree == 0
), f"the world size {self._world_size} should be divisible by the DP degree {dp_degree} when the TP degree is None"
self._tp_degree = self._world_size // dp_degree
elif not dp_degree and tp_degree:
self._tp_degree = tp_degree
assert self._world_size % self._tp_degree == 0, f"the world size {self._world_size} should be divisible by the TP degree {tp_degree} when the DP degree is None"
assert (
self._world_size % self._tp_degree == 0
), f"the world size {self._world_size} should be divisible by the TP degree {tp_degree} when the DP degree is None"
self._dp_degree = self._world_size // tp_degree
else:
self._dp_degree = dp_degree
self._tp_degree = tp_degree
assert self._dp_degree * self._tp_degree == self._world_size, \
f"the world size {self._world_size} should equal the product of the DP degree {self._dp_degree} " \
assert self._dp_degree * self._tp_degree == self._world_size, (
f"the world size {self._world_size} should equal the product of the DP degree {self._dp_degree} "
f"and the TP degree {self._tp_degree}"
)
self._tp_rank_list = None
self._dp_rank_list = None
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'nccl')
PYTORCHPGDICT_.get(i_tp_list, "nccl")
if self._rank in i_tp_list:
self._tp_rank_list = i_tp_list
for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'nccl')
PYTORCHPGDICT_.get(j_dp_list, "nccl")
if self._rank in j_dp_list:
self._dp_rank_list = j_dp_list
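
A local sketch of the rank bookkeeping above: with 8 ranks, dp_degree=4 and tp_degree=2, consecutive ranks form the TP groups and strided ranks form the DP groups, exactly as the two comprehensions compute.

rank_list, dp_degree, tp_degree = list(range(8)), 4, 2

tp_groups = [[rank_list[i * tp_degree + j] for j in range(tp_degree)] for i in range(dp_degree)]
dp_groups = [[rank_list[i * tp_degree + j] for i in range(dp_degree)] for j in range(tp_degree)]

print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(dp_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]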
@@ -119,11 +124,11 @@ class ProcessGroup:
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'gloo')
PYTORCHPGDICT_.get(i_tp_list, "gloo")
for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'gloo')
PYTORCHPGDICT_.get(j_dp_list, "gloo")
self._has_cpu_groups = True
@@ -145,7 +150,7 @@ class ProcessGroup:
else:
return "ProcessGroup not initialized"
def __eq__(self, obj: 'ProcessGroup') -> bool:
def __eq__(self, obj: "ProcessGroup") -> bool:
if not isinstance(obj, ProcessGroup):
return False
if self._rank != obj._rank:
@@ -260,7 +265,7 @@ class ProcessGroup:
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
return PYTORCHPGDICT_.get(self._dp_rank_list, "nccl")
def tp_process_group(self):
"""tp_process_group
@@ -270,7 +275,7 @@ class ProcessGroup:
Returns:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
return PYTORCHPGDICT_.get(self._tp_rank_list, "nccl")
def cpu_dp_process_group(self):
"""cpu_dp_process_group
@@ -283,7 +288,7 @@ class ProcessGroup:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch DP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
return PYTORCHPGDICT_.get(self._dp_rank_list, "gloo")
def cpu_tp_process_group(self):
"""cpu_tp_process_group
@@ -296,7 +301,7 @@ class ProcessGroup:
`torch._C._distributed_c10d.ProcessGroup`: the pytorch TP process group.
"""
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
return PYTORCHPGDICT_.get(self._tp_rank_list, "gloo")
def get_ranks_in_dp(self) -> List[int]:
"""get_ranks_in_dp

View File

@@ -9,12 +9,13 @@ from .compute_spec import ComputeSpec
@dataclass
class ColoTensorSpec:
""" ColoTensorSpec
"""ColoTensorSpec
A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional. If not set, they default to `Replicate()` and `None`, respectively.
"""
pg: ProcessGroup
dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE)
compute_attr: Optional[ComputeSpec] = None
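
A minimal construction sketch for ColoTensorSpec. Assumptions: torch.distributed is already initialised so a ProcessGroup can be built, and ComputePattern.TP1D exists (neither is shown in this diff); dist_attr and compute_attr are the optional fields declared above.

from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec

pg = ProcessGroup(tp_degree=2)
spec = ColoTensorSpec(
    pg,
    dist_attr=ShardSpec([0], [2]),                  # shard dim 0 across the 2 TP ranks
    compute_attr=ComputeSpec(ComputePattern.TP1D),  # run it with 1D tensor parallelism (assumed member)
)
print(spec.dist_attr, spec.compute_attr)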