[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -7,15 +7,15 @@ import torch.distributed as dist
from torch.distributed import ReduceOp
__all__ = [
'CollectiveCommPattern',
'CommSpec',
"CollectiveCommPattern",
"CommSpec",
]
def _all_gather(tensor, comm_spec):
'''
"""
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
"""
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
@@ -31,9 +31,9 @@ def _all_gather(tensor, comm_spec):
def _split(tensor, comm_spec):
'''
"""
Implement shard operation on device mesh based on information provided by comm_spec.
'''
"""
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
@@ -45,9 +45,9 @@ def _split(tensor, comm_spec):
def _all_to_all(tensor, comm_spec):
'''
"""
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
"""
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
@@ -66,9 +66,9 @@ def _all_to_all(tensor, comm_spec):
def _all_reduce(tensor, comm_spec, async_op=False):
'''
"""
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
"""
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
@@ -79,7 +79,7 @@ def _all_reduce(tensor, comm_spec, async_op=False):
def _mix_gather(tensor, comm_spec):
'''
"""
Implement mix gather operation on device mesh based on information provided by comm_spec.
Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
@@ -124,7 +124,7 @@ def _mix_gather(tensor, comm_spec):
leading_group_dim = 1
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
'''
"""
total_slices = comm_spec.device_mesh.shape[0]
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
leading_group_dim = comm_spec.logical_process_axes[0]
@@ -155,15 +155,16 @@ def _mix_gather(tensor, comm_spec):
torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
]
for i in range(cat_slice[1]):
tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
comm_spec.gather_dim[0]).contiguous()
tmp_tensor_list[i] = torch.cat(
tuple(tensor_list[i * cat_slice[0] : (i + 1) * cat_slice[0]]), comm_spec.gather_dim[0]
).contiguous()
output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
return output
def _mix_split(tensor, comm_spec):
'''
"""
Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent)
Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
@@ -177,7 +178,7 @@ def _mix_split(tensor, comm_spec):
# [[0, 1, 2, 3],
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
"""
mesh_shape = comm_spec.device_meshes.shape
dim = comm_spec.gather_dim
total_slices = comm_spec.device_mesh.shape[0]
@@ -316,11 +317,13 @@ class _AllToAll(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
sharding_spec=comm_spec.sharding_spec,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)
comm_spec_for_backward = CommSpec(
comm_pattern=comm_spec.comm_pattern,
sharding_spec=comm_spec.sharding_spec,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis,
)
ctx.comm_spec = comm_spec_for_backward
return output
@@ -330,7 +333,6 @@ class _AllToAll(torch.autograd.Function):
class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_):
return _mix_gather(input_)
@@ -370,16 +372,16 @@ def mixgather_forward_split_backward(input_, comm_spec):
class CollectiveCommPattern(Enum):
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
class CommSpec:
'''
"""
Communication spec is used to record the communication action. It has two main functions:
1. Compute the communication cost which will be used in auto parallel solver.
2. Convert the communication spec to real action which will be used in runtime.
@@ -393,16 +395,18 @@ class CommSpec:
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
"""
def __init__(self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False,
mix_gather=False):
def __init__(
self,
comm_pattern,
sharding_spec,
gather_dim=None,
shard_dim=None,
logical_process_axis=None,
forward_only=False,
mix_gather=False,
):
self.comm_pattern = comm_pattern
self.sharding_spec = sharding_spec
self.gather_dim = gather_dim
@@ -449,14 +453,14 @@ class CommSpec:
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_asex:{self.logical_process_axes})")
return ''.join(res_list)
return "".join(res_list)
def get_comm_cost(self):
'''
"""
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
"""
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
cost_dict = {}
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
@@ -500,13 +504,13 @@ class CommSpec:
return cost_dict
def covert_spec_to_action(self, tensor):
'''
"""
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
"""
if self.comm_pattern in pattern_to_func_dict:
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
else: