diff --git a/colossalai/auto_parallel/solver/_utils.py b/colossalai/auto_parallel/solver/_utils.py index 77faa706a..9cdc984cb 100644 --- a/colossalai/auto_parallel/solver/_utils.py +++ b/colossalai/auto_parallel/solver/_utils.py @@ -79,7 +79,7 @@ def generate_resharding_costs(nodes: List[Node], input_sharding_spec, input_spec) # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = total_resharding_cost * size_per_elem_bytes + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes except AssertionError as e: warnings.warn(f'{e}') resharding_cost = INFINITY_COST diff --git a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py index fb2e53dad..1ca5d6559 100644 --- a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py +++ b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py @@ -93,7 +93,7 @@ class BcastOpHandler(OperatorHandler): input_sharding_spec, input_spec) # we need multiply the size of elem dtype to get correct communication cost - resharding_cost = total_resharding_cost * size_per_elem_bytes + resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes resharding_costs[input_node].append(resharding_cost) return resharding_costs diff --git a/colossalai/auto_parallel/solver/strategy/strategy_generator.py b/colossalai/auto_parallel/solver/strategy/strategy_generator.py index 95d5ff26d..0cdc8b018 100644 --- a/colossalai/auto_parallel/solver/strategy/strategy_generator.py +++ b/colossalai/auto_parallel/solver/strategy/strategy_generator.py @@ -91,18 +91,11 @@ class StrategyGenerator_V2(ABC): num_ele_in_comm = comm_spec.get_comm_cost() dtype = operand.data.dtype size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() - cost = size_per_elem_bytes * num_ele_in_comm - - # compute the fwd - # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost. - # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used, - # so total cost is either for fwd or bwd. 
- if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: - comm_cost.fwd += cost - elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: - comm_cost.fwd += cost - else: - raise ValueError(f"Found unknown CommunicationType {comm_spec.comm_pattern}") + for phase, cost in num_ele_in_comm.items(): + num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes + comm_cost.fwd += num_ele_in_comm['forward'] + comm_cost.bwd += num_ele_in_comm['backward'] + comm_cost.total += num_ele_in_comm['total'] # check if communication action exists # if so, loop over each action and compute the cost of each action @@ -110,9 +103,6 @@ class StrategyGenerator_V2(ABC): for operand, comm_spec in strategy.communication_actions.items(): _compute_and_add(operand, comm_spec) - # update the total cost - comm_cost.total = comm_cost.fwd + comm_cost.bwd - # update the communication cost attribute in-place strategy.communication_cost = comm_cost return strategy diff --git a/colossalai/tensor/__init__.py b/colossalai/tensor/__init__.py index bf0e4bf34..4946d7077 100644 --- a/colossalai/tensor/__init__.py +++ b/colossalai/tensor/__init__.py @@ -9,10 +9,11 @@ from .colo_parameter import ColoParameter from .utils import convert_parameter, named_params_with_colotensor from .dist_spec_mgr import DistSpecManager from .param_op_hook import ParamOpHook, ParamOpHookManager +from .comm_spec import CollectiveCommPattern, CommSpec from . import distspec __all__ = [ 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec', - 'ReplicaSpec' + 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern' ] diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py new file mode 100644 index 000000000..8f51f21cf --- /dev/null +++ b/colossalai/tensor/comm_spec.py @@ -0,0 +1,358 @@ +import torch +from enum import Enum +import torch.distributed as dist +from functools import reduce +import operator +from torch.distributed import ReduceOp + +__all__ = [ + 'CollectiveCommPattern', + 'CommSpec', +] + + +def _all_gather(tensor, comm_spec): + ''' + Implement all gather operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + tensor_list = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) + for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) + ] + tensor = tensor + group = process_group + dist.all_gather(tensor_list, tensor, group=group) + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _split(tensor, comm_spec): + ''' + Implement shard operation on device mesh based on information provided by comm_spec. 
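+    Note: this is a local torch.narrow on the calling rank, so no collective communication kernel is launched.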
+ ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, _ in process_groups_list: + if dist.get_rank() in rank_list: + tensor = tensor + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + start = length * rank_list.index(dist.get_rank()) + output = torch.narrow(tensor, dim, start, length) + return output + + +def _all_to_all(tensor, comm_spec): + ''' + Implement all to all operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + new_shape = list(tensor.shape) + new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) + new_shape = torch.Size(new_shape) + output_tensor_list = [ + torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) + ] + dim = comm_spec.shard_dim + length = tensor.shape[comm_spec.shard_dim] // len(rank_list) + input_tensor_list = [ + torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) + ] + group = process_group + dist.all_to_all(output_tensor_list, input_tensor_list, group) + output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() + return output + + +def _all_reduce(tensor, comm_spec): + ''' + Implement all reduce operation on device mesh based on information provided by comm_spec. + ''' + process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] + for rank_list, process_group in process_groups_list: + if dist.get_rank() in rank_list: + dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) + return tensor + + +class _ReduceGrad(torch.autograd.Function): + """ + A customized communication operation which forward is an identity operation, + backward is all_reduce operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return input_ + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return input_ + + @staticmethod + def backward(ctx, grad_output): + return _all_reduce(grad_output, ctx.comm_spec), None + + +class _ReduceInput(torch.autograd.Function): + """ + A customized communication operation which forward is all_reduce operation, + backward is an identity operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_reduce(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + return _all_reduce(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return grad_output, None + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + A customized communication operation which forward is split operation, + backward is an all gather operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. 
+ """ + + @staticmethod + def symbolic(graph, input_): + return _split(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _split(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _all_gather(grad_output, ctx.comm_spec), None + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """ + A customized communication operation which forward is an all gather operation, + backward is split operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _all_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.comm_spec), None + + +class _AllToAll(torch.autograd.Function): + """ + A customized communication operation which forward is an all to all operation, + backward is an all to all operation. + + Args: + input_: input matrix. + comm_spec: comm_spec will give information like process group, rank list, etc. + """ + + @staticmethod + def symbolic(graph, input_): + return _all_to_all(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + output = _all_to_all(input_, comm_spec) + comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, + sharding_spec=comm_spec.sharding_spec, + gather_dim=comm_spec.shard_dim, + shard_dim=comm_spec.gather_dim, + logical_process_axis=comm_spec.logical_process_axis) + ctx.comm_spec = comm_spec_for_backward + return output + + @staticmethod + def backward(ctx, grad_outputs): + return _all_to_all(grad_outputs, ctx.comm_spec), None + + +def reduce_grad(input_, comm_spec): + return _ReduceGrad.apply(input_, comm_spec) + + +def reduce_input(input_, comm_spec): + return _ReduceInput.apply(input_, comm_spec) + + +def split_forward_gather_backward(input_, comm_spec): + return _SplitForwardGatherBackward.apply(input_, comm_spec) + + +def gather_forward_split_backward(input_, comm_spec): + return _GatherForwardSplitBackward.apply(input_, comm_spec) + + +def all_to_all(input_, comm_spec): + return _AllToAll.apply(input_, comm_spec) + + +class CollectiveCommPattern(Enum): + GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' + ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' + SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' + ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' + IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + + +class CommSpec: + ''' + Communication spec is used to record the communication action. It has two main functions: + 1. Compute the communication cost which will be used in auto parallel solver. + 2. Convert the communication spec to real action which will be used in runtime. + It contains comm_pattern to determine the + communication method, sharding_spec to determine the communication size, gather_dim and shard_dim + to determine the buffer shape, and logical_process_axis + + Argument: + comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. + sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. + gather_dim(int, Optional): The gather_dim of the tensor will be gathered. + shard_dim(int, Optional): The shard_dim of the tensor will be sharded. 
+ logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. + ''' + + def __init__(self, + comm_pattern, + sharding_spec, + gather_dim=None, + shard_dim=None, + logical_process_axis=None, + forward_only=False): + self.comm_pattern = comm_pattern + self.sharding_spec = sharding_spec + self.gather_dim = gather_dim + self.shard_dim = shard_dim + self.logical_process_axis = logical_process_axis + self.forward_only = forward_only + if isinstance(self.logical_process_axis, list): + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_mesh = self.sharding_spec.device_mesh + + def __repr__(self): + res_list = ["CommSpec:("] + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis: {self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") + res_list.append(f"shard_dim:{self.shard_dim}, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") + res_list.append(f"logical_process_axis:{self.logical_process_axis})") + + return ''.join(res_list) + + def get_comm_cost(self): + ''' + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. + For shard operation, it is an on-chip operation, so the communication cost is zero. + ''' + comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) + cost_dict = {} + if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + # give a tiny cost to shard + backward_communication_cost = 10 + + if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + # grad should have same shape as input tensor + # all to all operation has same logical process axis as forward. 
+ backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + backward_communication_cost = 0 + + if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + forward_communication_cost = 0 + backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) + + if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + # give a tiny cost to shard + forward_communication_cost = 10 + backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + + if self.forward_only: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = 0 + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + else: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = backward_communication_cost + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + + return cost_dict + + def covert_spec_to_action(self, tensor): + ''' + Convert CommSpec into runtime action, implement real collection communication to target tensor. + The collection communication action is directed by the CommSpec. + + Argument: + tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. + ''' + if self.comm_pattern in pattern_to_func_dict: + tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self) + else: + tensor.data = tensor + + +pattern_to_func_dict = { + CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, + CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, + CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, + CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, + CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, +} diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index d094a2c37..1f7dd2935 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -11,355 +11,9 @@ import math from functools import reduce import operator from torch.distributed import ReduceOp +from .comm_spec import * -__all__ = [ - 'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions', - 'set_shape_consistency_options' -] - - -def _all_gather(tensor, comm_spec): - ''' - Implement all gather operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - tensor_list = [ - torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) - for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) - ] - tensor = tensor - group = process_group - dist.all_gather(tensor_list, tensor, group=group) - output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() - return output - - -def _split(tensor, comm_spec): - ''' - Implement shard operation on device mesh based on information provided by comm_spec. 
- ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, _ in process_groups_list: - if dist.get_rank() in rank_list: - tensor = tensor - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - start = length * rank_list.index(dist.get_rank()) - output = torch.narrow(tensor, dim, start, length) - return output - - -def _all_to_all(tensor, comm_spec): - ''' - Implement all to all operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - new_shape = list(tensor.shape) - new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) - new_shape = torch.Size(new_shape) - output_tensor_list = [ - torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) - ] - dim = comm_spec.shard_dim - length = tensor.shape[comm_spec.shard_dim] // len(rank_list) - input_tensor_list = [ - torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) - ] - group = process_group - dist.all_to_all(output_tensor_list, input_tensor_list, group) - output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() - return output - - -def _all_reduce(tensor, comm_spec): - ''' - Implement all reduce operation on device mesh based on information provided by comm_spec. - ''' - process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] - for rank_list, process_group in process_groups_list: - if dist.get_rank() in rank_list: - dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) - return tensor - - -class _ReduceGrad(torch.autograd.Function): - """ - A customized communication operation which forward is an identity operation, - backward is all_reduce operation. - - Args: - input_: input matrix. - comm_spec: comm_spec will give information like process group, rank list, etc. - """ - - @staticmethod - def symbolic(graph, input_): - return input_ - - @staticmethod - def forward(ctx, input_, comm_spec): - ctx.comm_spec = comm_spec - return input_ - - @staticmethod - def backward(ctx, grad_output): - return _all_reduce(grad_output, ctx.comm_spec), None - - -class _ReduceInput(torch.autograd.Function): - """ - A customized communication operation which forward is all_reduce operation, - backward is an identity operation. - - Args: - input_: input matrix. - comm_spec: comm_spec will give information like process group, rank list, etc. - """ - - @staticmethod - def symbolic(graph, input_): - return _all_reduce(input_) - - @staticmethod - def forward(ctx, input_, comm_spec): - return _all_reduce(input_, comm_spec) - - @staticmethod - def backward(ctx, grad_output): - return grad_output, None - - -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - A customized communication operation which forward is split operation, - backward is an all gather operation. - - Args: - input_: input matrix. - comm_spec: comm_spec will give information like process group, rank list, etc. 
- """ - - @staticmethod - def symbolic(graph, input_): - return _split(input_) - - @staticmethod - def forward(ctx, input_, comm_spec): - ctx.comm_spec = comm_spec - return _split(input_, comm_spec) - - @staticmethod - def backward(ctx, grad_output): - return _all_gather(grad_output, ctx.comm_spec), None - - -class _GatherForwardSplitBackward(torch.autograd.Function): - """ - A customized communication operation which forward is an all gather operation, - backward is split operation. - - Args: - input_: input matrix. - comm_spec: comm_spec will give information like process group, rank list, etc. - """ - - @staticmethod - def symbolic(graph, input_): - return _all_gather(input_) - - @staticmethod - def forward(ctx, input_, comm_spec): - ctx.comm_spec = comm_spec - return _all_gather(input_, comm_spec) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.comm_spec), None - - -class _AllToAll(torch.autograd.Function): - """ - A customized communication operation which forward is an all to all operation, - backward is an all to all operation. - - Args: - input_: input matrix. - comm_spec: comm_spec will give information like process group, rank list, etc. - """ - - @staticmethod - def symbolic(graph, input_): - return _all_to_all(input_) - - @staticmethod - def forward(ctx, input_, comm_spec): - output = _all_to_all(input_, comm_spec) - comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, - sharding_spec=comm_spec.sharding_spec, - gather_dim=comm_spec.shard_dim, - shard_dim=comm_spec.gather_dim, - logical_process_axis=comm_spec.logical_process_axis) - ctx.comm_spec = comm_spec_for_backward - return output - - @staticmethod - def backward(ctx, grad_outputs): - return _all_to_all(grad_outputs, ctx.comm_spec), None - - -def reduce_grad(input_, comm_spec): - return _ReduceGrad.apply(input_, comm_spec) - - -def reduce_input(input_, comm_spec): - return _ReduceInput.apply(input_, comm_spec) - - -def split_forward_gather_backward(input_, comm_spec): - return _SplitForwardGatherBackward.apply(input_, comm_spec) - - -def gather_forward_split_backward(input_, comm_spec): - return _GatherForwardSplitBackward.apply(input_, comm_spec) - - -def all_to_all(input_, comm_spec): - return _AllToAll.apply(input_, comm_spec) - - -class CollectiveCommPattern(Enum): - GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' - ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' - SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' - ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' - IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' - - -class CommSpec: - ''' - Communication spec is used to record the communication action. It has two main functions: - 1. Compute the communication cost which will be used in auto parallel solver. - 2. Convert the communication spec to real action which will be used in runtime. - It contains comm_pattern to determine the - communication method, sharding_spec to determine the communication size, gather_dim and shard_dim - to determine the buffer shape, and logical_process_axis - - Argument: - comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. - sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. - gather_dim(int, Optional): The gather_dim of the tensor will be gathered. - shard_dim(int, Optional): The shard_dim of the tensor will be sharded. 
- logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. - ''' - - def __init__(self, - comm_pattern, - sharding_spec, - gather_dim=None, - shard_dim=None, - logical_process_axis=None, - forward_only=False): - self.comm_pattern = comm_pattern - self.sharding_spec = sharding_spec - self.gather_dim = gather_dim - self.shard_dim = shard_dim - self.logical_process_axis = logical_process_axis - self.forward_only = forward_only - if isinstance(self.logical_process_axis, list): - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh - self.logical_process_axis = 0 - else: - self.device_mesh = self.sharding_spec.device_mesh - - def __repr__(self): - res_list = ["CommSpec:("] - if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: - res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") - res_list.append(f"gather_dim:{self.gather_dim}, ") - res_list.append(f"logical_process_axis:{self.logical_process_axis})") - elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: - res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") - res_list.append(f"gather_dim:{self.gather_dim}, ") - res_list.append(f"shard_dim:{self.shard_dim}, ") - res_list.append(f"logical_process_axis: {self.logical_process_axis})") - elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: - res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") - res_list.append(f"shard_dim:{self.shard_dim}, ") - res_list.append(f"logical_process_axis:{self.logical_process_axis})") - elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: - res_list.append(f"comm_pattern:ALLREDUCE_FWD_IDENTITY_BWD, ") - res_list.append(f"logical_process_axis:{self.logical_process_axis})") - elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: - res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") - res_list.append(f"logical_process_axis:{self.logical_process_axis})") - - return ''.join(res_list) - - def get_comm_cost(self): - ''' - For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to - compute the communication cost. - For shard operation, it is an on-chip operation, so the communication cost is zero. - ''' - comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) - if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: - forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) - # give a tiny cost to shard - backward_communication_cost = 10 - - if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: - forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) - # grad should have same shape as input tensor - # all to all operation has same logical process axis as forward. 
- backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) - - if self.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: - forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) - backward_communication_cost = 0 - - if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: - forward_communication_cost = 0 - backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) - - if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: - # give a tiny cost to shard - forward_communication_cost = 10 - backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) - try: - if self.forward_only: - total_communication_cost = forward_communication_cost - else: - total_communication_cost = forward_communication_cost + backward_communication_cost - except: - raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.") - - return total_communication_cost - - def covert_spec_to_action(self, tensor): - ''' - Convert CommSpec into runtime action, implement real collection communication to target tensor. - The collection communication action is directed by the CommSpec. - - Argument: - tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. - ''' - if self.comm_pattern in pattern_to_func_dict: - tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self) - else: - tensor.data = tensor - - -pattern_to_func_dict = { - CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, - CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, - CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, - CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, - CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, -} +__all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] @dataclass @@ -406,7 +60,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): assert isinstance(value, bool) self._forward_only = value - def get_all_all_gather_spec(self, source_spec, orig_cost): + def get_all_all_gather_spec(self, source_spec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with single all-gather operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -463,16 +117,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): forward_only=self.forward_only) # compute the communication cost with CommSpec - cost = comm_spec.get_comm_cost() + cost_dict = comm_spec.get_comm_cost() # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) return valid_spec_dict - def get_all_all_to_all_spec(self, source_spec, orig_cost): + def get_all_all_to_all_spec(self, source_spec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. 
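The hunks in this file all follow the same pattern: CommSpec.get_comm_cost() now returns a dict keyed by 'forward', 'backward' and 'total' instead of a single float, and the cost accumulated so far is folded in per phase before the (comm_spec, cost_dict) pair is recorded. A minimal sketch of that accumulation, with made-up numbers purely for illustration:

    # illustrative numbers only; in the real code cost_dict comes from comm_spec.get_comm_cost()
    orig_cost_dict = {'forward': 0.0, 'backward': 0.0, 'total': 0.0}
    cost_dict = {'forward': 120.0, 'backward': 10.0, 'total': 130.0}

    # fold the previously accumulated cost into this candidate's per-phase cost
    for phase, cost in cost_dict.items():
        cost_dict[phase] = cost + orig_cost_dict[phase]

    # the resulting cost_dict is then stored alongside the CommSpec for the candidate sharding spec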
@@ -552,7 +208,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): forward_only=self.forward_only) # compute the communication cost with CommSpec - cost = comm_spec.get_comm_cost() + cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) # We won't add empty list into dim_partition_dict @@ -570,10 +226,12 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) return valid_spec_dict - def get_all_shard_spec(self, source_spec, orig_cost): + def get_all_shard_spec(self, source_spec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with single shard operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -639,16 +297,18 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): forward_only=self.forward_only) # compute the communication cost with CommSpec - cost = comm_spec.get_comm_cost() + cost_dict = comm_spec.get_comm_cost() # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) + for phase, cost in cost_dict.items(): + cost_dict[phase] = cost + orig_cost_dict[phase] + valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) return valid_spec_dict - def get_all_one_step_transform_spec(self, source_spec, orig_cost): + def get_all_one_step_transform_spec(self, source_spec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with one step transform, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. @@ -665,9 +325,9 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. ''' valid_spec_dict = {} - valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost)) - valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost)) - valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost)) + valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) + valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) return valid_spec_dict def shape_consistency(self, source_spec, target_spec): @@ -730,7 +390,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): total_cost: 12294.402000000002 ''' MAX_TRANSFORM_STEPS = 20 - total_cost = 0 + total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} total_steps = 0 transform_path = [] comm_action_sequence = [] @@ -740,35 +400,37 @@ class ShapeConsistencyManager(metaclass=SingletonMeta): # We do nothing if the sharding spec is all the same. 
if source_spec.sharding_sequence_difference(target_spec) == 0: self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) - return (transform_path, comm_action_sequence, total_cost) + return (transform_path, comm_action_sequence, total_cost_dict) temp_sharding_spec = source_spec transform_path.append(temp_sharding_spec) # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms while total_steps <= MAX_TRANSFORM_STEPS: - valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost) + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict) best_difference_score = math.inf for sharding_spec, info_pairs in valid_transform_spec_dict.items(): - comm_spec, cost = info_pairs + comm_spec, cost_dict = info_pairs spec_difference = sharding_spec.sharding_sequence_difference(target_spec) if spec_difference == 0: - total_cost += cost + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + cost_dict[phase] transform_path.append(sharding_spec) comm_action_sequence.append(comm_spec) self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) - return (transform_path, comm_action_sequence, total_cost) + return (transform_path, comm_action_sequence, total_cost_dict) if spec_difference < best_difference_score: temp_sharding_spec = sharding_spec - temp_cost = cost + temp_cost_dict = cost_dict temp_comm_spec = comm_spec best_difference_score = spec_difference transform_path.append(temp_sharding_spec) comm_action_sequence.append(temp_comm_spec) - total_cost += temp_cost + for phase, cost in total_cost_dict.items(): + total_cost_dict[phase] = cost + temp_cost_dict[phase] total_steps += 1 raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index c81bee5e0..6fe9ee292 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -27,7 +27,11 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} - rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0) + rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, { + "forward": 0, + "backward": 0, + "total": 0 + }) assert '[R, S1, R]' in [ str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys() @@ -48,7 +52,11 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} - rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0) + rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, { + "forward": 0, + "backward": 0, + "total": 0 + }) assert '[S01, R, R]' in [ str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys() @@ -72,7 +80,11 @@ def test_one_step_transform(): # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, 
shard_dim:1, logical_process_axis:1), 0), DistSpec: # shard_sequence: S0,R,S1 # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} - rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0) + rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, { + "forward": 0, + "backward": 0, + "total": 0 + }) assert '[S01, R, R]' in [ str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
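For reference, a sketch of how the updated ShapeConsistencyManager API is consumed after this change. The device mesh and sharding spec construction is elided, and the variable names simply mirror the test above:

    # sketch only: callers now pass and receive per-phase cost dicts
    zero_cost = {'forward': 0, 'backward': 0, 'total': 0}

    # each value in the returned dict is a (CommSpec, cost_dict) pair rather than (CommSpec, float)
    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, zero_cost)

    # shape_consistency() likewise returns a cost dict as its third element
    transform_path, comm_action_sequence, total_cost_dict = shape_consistency_manager.shape_consistency(
        source_spec, target_spec)
    print(total_cost_dict['forward'], total_cost_dict['backward'], total_cost_dict['total'])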