mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,15 +7,15 @@ import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
__all__ = [
|
||||
'CollectiveCommPattern',
|
||||
'CommSpec',
|
||||
"CollectiveCommPattern",
|
||||
"CommSpec",
|
||||
]
|
||||
|
||||
|
||||
def _all_gather(tensor, comm_spec):
|
||||
'''
|
||||
"""
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
"""
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
@@ -31,9 +31,9 @@ def _all_gather(tensor, comm_spec):
|
||||
|
||||
|
||||
def _split(tensor, comm_spec):
|
||||
'''
|
||||
"""
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
"""
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
@@ -45,9 +45,9 @@ def _split(tensor, comm_spec):
|
||||
|
||||
|
||||
def _all_to_all(tensor, comm_spec):
|
||||
'''
|
||||
"""
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
"""
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
@@ -66,9 +66,9 @@ def _all_to_all(tensor, comm_spec):
|
||||
|
||||
|
||||
def _all_reduce(tensor, comm_spec, async_op=False):
|
||||
'''
|
||||
"""
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
"""
|
||||
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
|
||||
process_group = process_groups[comm_spec.logical_process_axis]
|
||||
|
||||
@@ -79,7 +79,7 @@ def _all_reduce(tensor, comm_spec, async_op=False):
|
||||
|
||||
|
||||
def _mix_gather(tensor, comm_spec):
|
||||
'''
|
||||
"""
|
||||
Implement mix gather operation on device mesh based on information provided by comm_spec.
|
||||
Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is
|
||||
different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather
|
||||
@@ -124,7 +124,7 @@ def _mix_gather(tensor, comm_spec):
|
||||
leading_group_dim = 1
|
||||
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
|
||||
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
|
||||
'''
|
||||
"""
|
||||
total_slices = comm_spec.device_mesh.shape[0]
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
|
||||
leading_group_dim = comm_spec.logical_process_axes[0]
|
||||
@@ -155,15 +155,16 @@ def _mix_gather(tensor, comm_spec):
|
||||
torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1])
|
||||
]
|
||||
for i in range(cat_slice[1]):
|
||||
tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]),
|
||||
comm_spec.gather_dim[0]).contiguous()
|
||||
tmp_tensor_list[i] = torch.cat(
|
||||
tuple(tensor_list[i * cat_slice[0] : (i + 1) * cat_slice[0]]), comm_spec.gather_dim[0]
|
||||
).contiguous()
|
||||
output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _mix_split(tensor, comm_spec):
|
||||
'''
|
||||
"""
|
||||
Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent)
|
||||
Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split
|
||||
because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension.
|
||||
@@ -177,7 +178,7 @@ def _mix_split(tensor, comm_spec):
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7]]
|
||||
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
|
||||
'''
|
||||
"""
|
||||
mesh_shape = comm_spec.device_meshes.shape
|
||||
dim = comm_spec.gather_dim
|
||||
total_slices = comm_spec.device_mesh.shape[0]
|
||||
@@ -316,11 +317,13 @@ class _AllToAll(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input_, comm_spec):
|
||||
output = _all_to_all(input_, comm_spec)
|
||||
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
|
||||
sharding_spec=comm_spec.sharding_spec,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis)
|
||||
comm_spec_for_backward = CommSpec(
|
||||
comm_pattern=comm_spec.comm_pattern,
|
||||
sharding_spec=comm_spec.sharding_spec,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis,
|
||||
)
|
||||
ctx.comm_spec = comm_spec_for_backward
|
||||
return output
|
||||
|
||||
@@ -330,7 +333,6 @@ class _AllToAll(torch.autograd.Function):
|
||||
|
||||
|
||||
class _MixGatherForwardMixSplitBackward(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _mix_gather(input_)
|
||||
@@ -370,16 +372,16 @@ def mixgather_forward_split_backward(input_, comm_spec):
|
||||
|
||||
|
||||
class CollectiveCommPattern(Enum):
|
||||
GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
|
||||
ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
|
||||
SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
|
||||
ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
|
||||
IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
|
||||
GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
|
||||
ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
|
||||
SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
|
||||
ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
|
||||
IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
|
||||
MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
|
||||
|
||||
|
||||
class CommSpec:
|
||||
'''
|
||||
"""
|
||||
Communication spec is used to record the communication action. It has two main functions:
|
||||
1. Compute the communication cost which will be used in auto parallel solver.
|
||||
2. Convert the communication spec to real action which will be used in runtime.
|
||||
@@ -393,16 +395,18 @@ class CommSpec:
|
||||
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
|
||||
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
|
||||
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
comm_pattern,
|
||||
sharding_spec,
|
||||
gather_dim=None,
|
||||
shard_dim=None,
|
||||
logical_process_axis=None,
|
||||
forward_only=False,
|
||||
mix_gather=False):
|
||||
def __init__(
|
||||
self,
|
||||
comm_pattern,
|
||||
sharding_spec,
|
||||
gather_dim=None,
|
||||
shard_dim=None,
|
||||
logical_process_axis=None,
|
||||
forward_only=False,
|
||||
mix_gather=False,
|
||||
):
|
||||
self.comm_pattern = comm_pattern
|
||||
self.sharding_spec = sharding_spec
|
||||
self.gather_dim = gather_dim
|
||||
@@ -449,14 +453,14 @@ class CommSpec:
|
||||
res_list.append(f"gather_dim:{self.gather_dim}, ")
|
||||
res_list.append(f"logical_process_asex:{self.logical_process_axes})")
|
||||
|
||||
return ''.join(res_list)
|
||||
return "".join(res_list)
|
||||
|
||||
def get_comm_cost(self):
|
||||
'''
|
||||
"""
|
||||
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
|
||||
compute the communication cost.
|
||||
For shard operation, it is an on-chip operation, so the communication cost is zero.
|
||||
'''
|
||||
"""
|
||||
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
|
||||
cost_dict = {}
|
||||
if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
|
||||
@@ -500,13 +504,13 @@ class CommSpec:
|
||||
return cost_dict
|
||||
|
||||
def covert_spec_to_action(self, tensor):
|
||||
'''
|
||||
"""
|
||||
Convert CommSpec into runtime action, implement real collection communication to target tensor.
|
||||
The collection communication action is directed by the CommSpec.
|
||||
|
||||
Argument:
|
||||
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
|
||||
'''
|
||||
"""
|
||||
if self.comm_pattern in pattern_to_func_dict:
|
||||
tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
|
||||
else:
|
||||
|
Reference in New Issue
Block a user