Mirror of https://github.com/hpcaitech/ColossalAI.git
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -6,22 +6,22 @@ import torch.distributed as dist
 from torch.distributed import ReduceOp
 
 __all__ = [
-    'CollectiveCommPattern',
-    'CommSpec',
+    "CollectiveCommPattern",
+    "CommSpec",
 ]
 
 
 class CollectiveCommPattern(Enum):
-    GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'
-    ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'
-    SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'
-    ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd'
-    IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd'
+    GATHER_FWD_SPLIT_BWD = "gather_fwd_split_bwd"
+    ALL2ALL_FWD_ALL2ALL_BWD = "all2all_fwd_all2all_bwd"
+    SPLIT_FWD_GATHER_BWD = "split_fwd_gather_bwd"
+    ALLREDUCE_FWD_IDENTITY_BWD = "all_reduce_fwd_identity_bwd"
+    IDENTITY_FWD_ALLREDUCE_BWD = "identity_fwd_all_reduce_bwd"
     MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd"
 
 
 class CommSpec:
-    '''
+    """
     Communication spec is used to record the communication action. It converts the communication spec
     to real action which will be used in runtime. It contains comm_pattern to determine the
     communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
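Each enum member pairs a forward collective with the inverse op replayed during backward. A single-process sketch (not part of the diff) of why gather and split along the same dimension undo each other, which is the contract behind GATHER_FWD_SPLIT_BWD:

    import torch

    # One shard per hypothetical rank; gathering then re-splitting on the same
    # dimension round-trips the shards, so a gather in forward can be undone by
    # a split in backward.
    shards = [torch.full((2, 2), float(r)) for r in range(4)]
    gathered = torch.cat(shards, dim=1)               # forward: gather on dim 1
    resplit = torch.chunk(gathered, chunks=4, dim=1)  # backward: split on dim 1
    assert all(torch.equal(a, b) for a, b in zip(shards, resplit))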
@@ -33,14 +33,16 @@ class CommSpec:
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
         logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
-    '''
+    """
 
-    def __init__(self,
-                 comm_pattern: CollectiveCommPattern,
-                 process_group_dict: Dict,
-                 gather_dim: int = None,
-                 shard_dim: int = None,
-                 logical_process_axis: int = None):
+    def __init__(
+        self,
+        comm_pattern: CollectiveCommPattern,
+        process_group_dict: Dict,
+        gather_dim: int = None,
+        shard_dim: int = None,
+        logical_process_axis: int = None,
+    ):
         self.comm_pattern = comm_pattern
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
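With the reformatted signature, construction reads one argument per line. A hypothetical usage sketch (the process-group dict below is a placeholder; in ColossalAI it is derived from the device mesh, and torch.distributed must already be initialized):

    import torch.distributed as dist

    # Map each logical process axis to a ProcessGroup; axis 0 here spans all ranks.
    process_group_dict = {0: dist.new_group(ranks=list(range(dist.get_world_size())))}
    comm_spec = CommSpec(
        comm_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
        process_group_dict=process_group_dict,
        gather_dim=0,
        shard_dim=0,
        logical_process_axis=0,
    )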
@@ -71,16 +73,16 @@ class CommSpec:
             res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ")
             res_list.append(f"logical_process_axis:{self.logical_process_axis})")
 
-        return ''.join(res_list)
+        return "".join(res_list)
 
     def covert_spec_to_action(self, tensor):
-        '''
+        """
         Convert CommSpec into runtime action, implement real collection communication to target tensor.
         The collection communication action is directed by the CommSpec.
 
         Argument:
             tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
-        '''
+        """
         if self.comm_pattern in pattern_to_func_dict:
             tensor = pattern_to_func_dict[self.comm_pattern](tensor, self)
         else:
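The dispatch above keys a module-level table by pattern, with every handler sharing the (tensor, comm_spec) signature. An abbreviated sketch of the idiom (the exact contents of pattern_to_func_dict sit outside this hunk, so the mapping below is illustrative only):

    # Illustrative only: pattern -> handler, all with the (tensor, comm_spec) shape.
    pattern_to_func_dict = {
        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: _all_gather,
        CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: _split,
        CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: _all_to_all,
        CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: _all_reduce,
    }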
@@ -89,9 +91,9 @@ class CommSpec:
 
 
 def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
-    '''
+    """
     Implement all gather operation on device mesh based on information provided by comm_spec.
-    '''
+    """
     process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
     world_size = dist.get_world_size(process_group)
     tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
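The hunk stops after the buffer allocation; a minimal sketch of how such a gather typically completes (an assumed continuation, not the verbatim file): every rank contributes its shard, then the shards are concatenated along gather_dim.

    def _all_gather_sketch(tensor: torch.Tensor, comm_spec: "CommSpec") -> torch.Tensor:
        process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
        world_size = dist.get_world_size(process_group)
        tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
        dist.all_gather(tensor_list, tensor, group=process_group)
        # Stitch the per-rank shards back together on the gather dimension.
        return torch.cat(tensor_list, dim=comm_spec.gather_dim).contiguous()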
@@ -103,9 +105,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
 
 
 def _split(tensor: torch.Tensor, comm_spec: CommSpec):
-    '''
+    """
     Implement shard operation on device mesh based on information provided by comm_spec.
-    '''
+    """
     process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
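With length computed as the per-rank shard size, the slice that typically follows keeps only the rank-local chunk; a hedged sketch of the assumed continuation:

    def _split_sketch(tensor: torch.Tensor, comm_spec: "CommSpec") -> torch.Tensor:
        process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
        dim = comm_spec.shard_dim
        length = tensor.shape[dim] // dist.get_world_size(process_group)
        # Rank r keeps elements [r * length, (r + 1) * length) along shard_dim.
        start = length * dist.get_rank(process_group)
        return torch.narrow(tensor, dim, start, length).contiguous()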
@@ -115,9 +117,9 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
 
 
 def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
-    '''
+    """
     Implement all to all operation on device mesh based on information provided by comm_spec.
-    '''
+    """
     process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
     world_size = dist.get_world_size(process_group)
     new_shape = list(tensor.shape)
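One way the rest of an all-to-all can be realized (illustrative only; the real body continues from the new_shape computation above, and this version assumes shard_dim divides evenly by the group size):

    def _all_to_all_sketch(tensor: torch.Tensor, comm_spec: "CommSpec") -> torch.Tensor:
        process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
        world_size = dist.get_world_size(process_group)
        # Send the i-th chunk along shard_dim to rank i, receive one chunk per
        # peer, then reassemble the received chunks along gather_dim.
        inputs = [c.contiguous() for c in torch.chunk(tensor, world_size, dim=comm_spec.shard_dim)]
        outputs = [torch.empty_like(c) for c in inputs]
        dist.all_to_all(outputs, inputs, group=process_group)
        return torch.cat(outputs, dim=comm_spec.gather_dim).contiguous()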
@@ -134,9 +136,9 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
 
 
 def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
-    '''
+    """
     Implement all reduce operation on device mesh based on information provided by comm_spec.
-    '''
+    """
     process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
     if not tensor.is_contiguous():
         tensor = tensor.contiguous()
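After the contiguity check, the reduction itself is a single collective; a minimal sketch of the assumed remainder (an in-place sum across the group, with async handling elided):

    def _all_reduce_sketch(tensor: torch.Tensor, comm_spec: "CommSpec", async_op: bool = False) -> torch.Tensor:
        process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
        if not tensor.is_contiguous():
            tensor = tensor.contiguous()
        # Sums tensor in place across every rank in the group.
        dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
        return tensor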
@@ -256,11 +258,13 @@ class _AllToAll(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input_, comm_spec):
         output = _all_to_all(input_, comm_spec)
-        comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
-                                          process_group_dict=comm_spec.process_group_dict,
-                                          gather_dim=comm_spec.shard_dim,
-                                          shard_dim=comm_spec.gather_dim,
-                                          logical_process_axis=comm_spec.logical_process_axis)
+        comm_spec_for_backward = CommSpec(
+            comm_pattern=comm_spec.comm_pattern,
+            process_group_dict=comm_spec.process_group_dict,
+            gather_dim=comm_spec.shard_dim,
+            shard_dim=comm_spec.gather_dim,
+            logical_process_axis=comm_spec.logical_process_axis,
+        )
         ctx.comm_spec = comm_spec_for_backward
         return output
 
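The swap of gather_dim and shard_dim in the saved spec is the whole trick: replaying the same all-to-all with the two dims exchanged undoes the forward shuffle. A condensed sketch of the full autograd pairing (the backward body sits outside this hunk, so it is inferred, not quoted):

    class _AllToAllSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_, comm_spec):
            ctx.comm_spec = CommSpec(
                comm_pattern=comm_spec.comm_pattern,
                process_group_dict=comm_spec.process_group_dict,
                gather_dim=comm_spec.shard_dim,  # swapped on purpose: backward gathers
                shard_dim=comm_spec.gather_dim,  # where forward sharded, and vice versa
                logical_process_axis=comm_spec.logical_process_axis,
            )
            return _all_to_all(input_, comm_spec)

        @staticmethod
        def backward(ctx, grad_output):
            # Same collective, opposite direction; None matches the comm_spec arg.
            return _all_to_all(grad_output, ctx.comm_spec), None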