[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -6,22 +6,22 @@ import torch.distributed as dist
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern',
'CommSpec',
"CollectiveCommPattern",
"CommSpec",
]
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
class CommSpec:
'''
"""
Communication spec is used to record the communication action. It converts the communication spec
to real action which will be used in runtime. It contains comm_pattern to determine the
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
@@ -33,14 +33,16 @@ class CommSpec:
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
"""
def __init__(self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
def __init__(
self,
comm_pattern: CollectiveCommPattern,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None,
):
self.comm_pattern = comm_pattern
self.gather_dim = gather_dim
self.shard_dim = shard_dim
@@ -71,16 +73,16 @@ class CommSpec:
res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
return "".join(res_list)
def covert_spec_to_action(self, tensor):
'''
"""
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
"""
if self.comm_pattern in pattern_to_func_dict:
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
else:
@@ -89,9 +91,9 @@ class CommSpec:
def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
@@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement shard operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
@@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
"""
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
@@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
'''
"""
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
"""
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
@@ -256,11 +258,13 @@ class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
comm_spec_for_backward = CommSpec(
comm_pattern=comm_spec.comm_pattern,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis,
)
ctx.comm_spec = comm_spec_for_backward
return output