From d655eea515c1dc6c68802a86a1b8c49ff32b6038 Mon Sep 17 00:00:00 2001 From: Genghan Zhang <58754328+zhang677@users.noreply.github.com> Date: Wed, 23 Nov 2022 21:49:17 +0800 Subject: [PATCH] [autoparallel] mix gather (#1977) * Add mix-gather * Add comments * Add comments * Polish comments * Change the global rank assumption * Add tests * Add two-step tests * Fix 10 and 01 * Skip test becasue the number of GPUs --- colossalai/device/device_mesh.py | 38 +++ colossalai/tensor/comm_spec.py | 170 ++++++++++++- colossalai/tensor/shape_consistency.py | 55 +++- colossalai/tensor/utils.py | 25 ++ tests/test_tensor/test_mix_gather.py | 333 +++++++++++++++++++++++++ 5 files changed, 617 insertions(+), 4 deletions(-) create mode 100644 tests/test_tensor/test_mix_gather.py diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py index 403bbe4ae..b77fe5eef 100644 --- a/colossalai/device/device_mesh.py +++ b/colossalai/device/device_mesh.py @@ -52,6 +52,9 @@ class DeviceMesh: self.process_groups_dict = self.create_process_groups_for_logical_mesh() if self.need_flatten: self.flatten_device_mesh = self.flatten() + # Create a new member `flatten_device_meshes` to distinguish from original flatten methods (Because I'm not sure if there are functions that rely on the self.flatten()) + self.flatten_device_meshes = FlattenDeviceMesh(self.physical_mesh_id, self.mesh_shape, self.mesh_alpha, + self.mesh_beta) @property def shape(self): @@ -199,3 +202,38 @@ class DeviceMesh: penalty_factor = num_devices / 2.0 return (self.mesh_alpha[mesh_dim] + self.mesh_beta[mesh_dim] * (num_devices - 1) / num_devices / num_devices * num_bytes * penalty_factor + 0.001) + + +class FlattenDeviceMesh(DeviceMesh): + + def __init__(self, physical_mesh_id, mesh_shape, mesh_alpha=None, mesh_beta=None): + super().__init__(physical_mesh_id, + mesh_shape, + mesh_alpha, + mesh_beta, + init_process_group=False, + need_flatten=False) + # Different from flatten(), mesh_shape leaves unchanged, mesh_alpha and mesh_beta are scalars + self.mesh_alpha = max(self.mesh_alpha) + self.mesh_beta = min(self.mesh_beta) + # Different from original process_groups_dict, rank_list is not stored + self.process_number_dict = self.create_process_numbers_for_logical_mesh() + + def create_process_numbers_for_logical_mesh(self): + ''' + Build 1d DeviceMesh in column-major(0) and row-major(1) + for example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + num_devices = reduce(operator.mul, self.mesh_shape, 1) + process_numbers_dict = {} + process_numbers_dict[0] = torch.arange(num_devices).reshape(self.mesh_shape).transpose(1, 0).flatten().tolist() + process_numbers_dict[1] = torch.arange(num_devices).reshape(self.mesh_shape).flatten().tolist() + return process_numbers_dict + + def mix_gather_cost(self, num_bytes): + num_devices = reduce(operator.mul, self.mesh_shape, 1) + return (self.mesh_alpha + self.mesh_beta * (num_devices - 1) / num_devices * num_bytes + 0.1) diff --git a/colossalai/tensor/comm_spec.py b/colossalai/tensor/comm_spec.py index 2910ea843..c8539d38d 100644 --- a/colossalai/tensor/comm_spec.py +++ b/colossalai/tensor/comm_spec.py @@ -79,6 +79,132 @@ def _all_reduce(tensor, comm_spec, async_op=False): return tensor +def _mix_gather(tensor, comm_spec): + ''' + Implement mix gather operation on device mesh based on information provided by comm_spec. 
+ Mix gather is the all-gather operation on all devices in the device_mesh(FlattenDeviceMesh) of the comm_spec. It is + different from _all_gather because _mix_gather does all-gather in two dimensions of device mesh, while _all_gather + only does all-gather in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + ShardingSpec => gather_dim, logical_process_axes + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (1, 1) + RS01 => [b], (1, 1) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + S0S1: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(0,1),(0,2),(0,3),(1,0),(1,1),(1,2),(1,3)] # [(slice_id_f, slice_id_b),...] + mesh_shape = (2,4) + cat_slice = [4,2] + tmp_tensor_list = [(...,shape[f],shape[b]*4,...),(...,shape[f],shape[b]*4,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1),(0,2),(0,3)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1),(1,2),(1,3)), dim=b) + output = torch.cat((tmp_tensor_list[0],tmp_tensor_list[1]), dim=a) + S1S0: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(0,1),(1,0),(1,1),(2,0),(2,1),(3,0),(3,1)] + mesh_shape = (2,4) + cat_slice = [2,4] + tmp_tensor_list = [(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...),(...,shape[f],shape[b]*2,...)] + tmp_tensor_list[0] = torch.cat(((0,0),(0,1)), dim=b) + tmp_tensor_list[1] = torch.cat(((1,0),(1,1)), dim=b) + tmp_tensor_list[2] = torch.cat(((2,0),(2,1)), dim=b) + tmp_tensor_list[3] = torch.cat(((3,0),(3,1)), dim=b) + S10R: + leading_group_dim = 0 + process_group = "[0, 4, 1, 5, 2, 6, 3, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + S01R: + leading_group_dim = 1 + process_group = "[0, 1, 2, 3, 4, 5, 6, 7]" + tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)] + ''' + total_slices = comm_spec.device_mesh.mesh_shape[0] + tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)] + leading_group_dim = comm_spec.logical_process_axes[0] + assert len(comm_spec.device_mesh.process_groups_dict) == 1 + _, process_group = comm_spec.device_mesh.process_groups_dict[0][0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + + # Global all_gather + dist.all_gather(tensor_list, tensor, group=process_group) + + # This is very ugly. 
I'm figuring out more elegant methods + tensor_list_sorted = [ + torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices) + ] + for i in range(total_slices): + tensor_list_sorted[i] = tensor_list[process_number_list[i]] + tensor_list = tensor_list_sorted + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous() + else: + mesh_shape = comm_spec.device_meshes.mesh_shape + cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + tmp_tensor_shape = list(tensor.shape) + tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0] + tmp_tensor_shape = torch.Size(tmp_tensor_shape) + tmp_tensor_list = [ + torch.zeros(tmp_tensor_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(cat_slice[1]) + ] + for i in range(cat_slice[1]): + tmp_tensor_list[i] = torch.cat(tuple(tensor_list[i * cat_slice[0]:(i + 1) * cat_slice[0]]), + comm_spec.gather_dim[0]).contiguous() + output = torch.cat(tuple(tmp_tensor_list), comm_spec.gather_dim[1]).contiguous() + + return output + + +def _mix_split(tensor, comm_spec): + ''' + Implement mix split operation. Mix split is only called for the backward of mix gather (Use ctx to keep consistent) + Mix split shards the tensor on device mesh based on information provided by comm_spec. It is different from split + because _mix_split shards the tensor in two dimensions of device mesh, while _split only shards in one dimension. + Assume index of f and b target pairs are 'f' and 'b' + S0S1 => [b, f], (1, 0) + S1S0 => [b, f], (0, 1) + S01R => [f], (0, 0) + RS01 => [b], (0, 0) + Example: + mesh_shape = (2,4) + # [[0, 1, 2, 3], + # [4, 5, 6, 7]] + # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]} + ''' + mesh_shape = comm_spec.device_meshes.mesh_shape + dim = comm_spec.gather_dim + total_slices = comm_spec.device_mesh.mesh_shape[0] + + # Get global rank + rank = dist.get_rank() + + leading_group_dim = comm_spec.logical_process_axes[0] + process_number_list = comm_spec.device_meshes.process_number_dict[leading_group_dim] + rank = process_number_list.index(rank) + + if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]: + length = tensor.shape[dim[0]] // total_slices + start = length * rank + output = torch.narrow(tensor, dim[0], start, length).contiguous() + else: + tensor_shape = [tensor.shape[dim[0]], tensor.shape[dim[1]]] + rank_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]] + length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]] + start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]] + tmp_output = torch.narrow(tensor, dim[0], start[0], length[0]).contiguous() + output = torch.narrow(tmp_output, dim[1], start[1], length[1]).contiguous() + + return output + + class _ReduceGrad(torch.autograd.Function): """ A customized communication operation which forward is an identity operation, @@ -204,6 +330,22 @@ class _AllToAll(torch.autograd.Function): return _all_to_all(grad_outputs, ctx.comm_spec), None +class _MixGatherForwardMixSplitBackward(torch.autograd.Function): + + @staticmethod + def symbolic(graph, input_): + return _mix_gather(input_) + + @staticmethod + def forward(ctx, input_, comm_spec): + ctx.comm_spec = comm_spec + return _mix_gather(input_, comm_spec) + + @staticmethod + def backward(ctx, grad_output): + return 
_mix_split(grad_output, ctx.comm_spec), None + + def reduce_grad(input_, comm_spec): return _ReduceGrad.apply(input_, comm_spec) @@ -224,12 +366,17 @@ def all_to_all(input_, comm_spec): return _AllToAll.apply(input_, comm_spec) +def mixgather_forward_split_backward(input_, comm_spec): + return _MixGatherForwardMixSplitBackward.apply(input_, comm_spec) + + class CollectiveCommPattern(Enum): GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' ALLREDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' + MIXGATHER_FWD_SPLIT_BWD = "mixgather_fwd_split_bwd" class CommSpec: @@ -255,7 +402,8 @@ class CommSpec: gather_dim=None, shard_dim=None, logical_process_axis=None, - forward_only=False): + forward_only=False, + mix_gather=False): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim @@ -263,8 +411,14 @@ class CommSpec: self.logical_process_axis = logical_process_axis self.forward_only = forward_only if isinstance(self.logical_process_axis, list): - self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh - self.logical_process_axis = 0 + if not mix_gather: + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + self.logical_process_axis = 0 + else: + self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes + self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh + # Create a new member `logical_process_axes` to distinguish from original flatten + self.logical_process_axes = logical_process_axis else: self.device_mesh = self.sharding_spec.device_mesh @@ -289,6 +443,10 @@ class CommSpec: elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") + elif self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + res_list.append(f"comm_pattern:MIXGATHER_FWD_SPLIT_BWD, ") + res_list.append(f"gather_dim:{self.gather_dim}, ") + res_list.append(f"logical_process_asex:{self.logical_process_axes})") return ''.join(res_list) @@ -324,6 +482,11 @@ class CommSpec: forward_communication_cost = 10 backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + if self.comm_pattern == CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: + # no need for axis because all devices are used in mix_gather + forward_communication_cost = self.device_mesh.mix_gather_cost(comm_size) + backward_communication_cost = 10 + if self.forward_only: cost_dict["forward"] = forward_communication_cost cost_dict["backward"] = 0 @@ -356,4 +519,5 @@ pattern_to_func_dict = { CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: reduce_input, CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, + CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: mixgather_forward_split_backward, } diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py index d5d28db0f..d566e3515 100644 --- a/colossalai/tensor/shape_consistency.py +++ b/colossalai/tensor/shape_consistency.py @@ -7,7 +7,7 @@ import torch from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException -from colossalai.tensor.utils import all_gather_simulator, 
all_to_all_simulator, shard_simulator
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
 
 from .comm_spec import *
 
@@ -328,6 +328,59 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                 pass
         return valid_spec_dict
 
+    def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
+                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
+        '''
+        S0S1 -> RR
+        S1S0 -> RR
+        S01R -> RR
+        RS01 -> RR
+        '''
+        valid_spec_dict = {}
+        comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
+        tensor_dims = len(source_spec.entire_shape)
+        for f_index in range(tensor_dims - 1):
+            for b_index in range(f_index + 1, tensor_dims):
+                if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
+                    continue
+                else:
+                    if f_index in source_spec.dim_partition_dict:
+                        f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
+                        # skip (S10, R) -> (R, R)
+                        if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
+                            continue
+                    else:
+                        f_target_pair = (f_index, [])
+                    if b_index in source_spec.dim_partition_dict:
+                        b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
+                        # skip (R, S10) -> (R, R)
+                        if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
+                            continue
+                    else:
+                        b_target_pair = (b_index, [])
+
+                gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
+                comm_spec = CommSpec(comm_pattern,
+                                     sharding_spec=source_spec,
+                                     gather_dim=gather_dim,
+                                     logical_process_axis=logical_process_axes,
+                                     forward_only=self.forward_only,
+                                     mix_gather=True)
+                cost_dict = comm_spec.get_comm_cost()
+                new_dim_partition_dict = {}
+                # generate new sharding spec
+                try:
+                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                     source_spec.entire_shape,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    for phase, cost in cost_dict.items():
+                        cost_dict[phase] = cost + orig_cost_dict[phase]
+                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
+                except ShardingSpecException:
+                    pass
+
+        return valid_spec_dict
+
     def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
         '''
         Get all valid sharding specs from source_spec with one step transform, and
diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py
index c5ffc9fb5..0c2ead630 100644
--- a/colossalai/tensor/utils.py
+++ b/colossalai/tensor/utils.py
@@ -90,6 +90,31 @@ def shard_simulator(target_pair, legal_sharding_dims):
     return shard_list_list
 
 
+def mix_gather_simulator(f_target_pair, b_target_pair):
+    '''
+    Assume the tensor dims of the f and b target pairs are 'f' and 'b'
+    S0S1 => Input: (f, [0]), (b, [1]) Output: [b, f], (1, 0)
+    S1S0 => Input: (f, [1]), (b, [0]) Output: [b, f], (0, 1)
+    S01R => Input: (f, [0, 1]), (b, []) Output: [f], (1, 1)
+    RS01 => Input: (f, []), (b, [0, 1]) Output: [b], (1, 1)
+    S10R => Input: (f, [0, 1]), (b, []) Output: [f], (0, 0)
+    RS10 => Input: (f, []), (b, [0, 1]) Output: [b], (0, 0)
+    '''
+    if f_target_pair[1] and b_target_pair[1]:
+        leading_dim = b_target_pair[1] > f_target_pair[1]
+        return [b_target_pair[0], f_target_pair[0]], [int(leading_dim), int(leading_dim ^ 1)]
+    if f_target_pair[1]:
+        leading_dim = f_target_pair[1][0] < f_target_pair[1][1]
+        return [
+            f_target_pair[0],
+        ], [int(leading_dim), int(leading_dim)]
+    if b_target_pair[1]:
+        leading_dim = b_target_pair[1][0] < b_target_pair[1][1]
+        return [
+            b_target_pair[0],
+        ], [int(leading_dim), 
int(leading_dim)] + + # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, diff --git a/tests/test_tensor/test_mix_gather.py b/tests/test_tensor/test_mix_gather.py new file mode 100644 index 000000000..c1ab30601 --- /dev/null +++ b/tests/test_tensor/test_mix_gather.py @@ -0,0 +1,333 @@ +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.core import global_context as gpc +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec +from colossalai.tensor.sharding_spec import ShardingSpec +from colossalai.tensor.utils import mix_gather_simulator +from colossalai.utils import free_port + + +def check_mix_gather_S0S1(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0]) + b_target_pair = (b, [1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [4, 2] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S0S1(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + dim_partition_dict = {0: [0], 1: [1]} + + tensor_slice = [tensor_width // 2, tensor_width // 4] # (4, 2) + rank_slice = 4 + f_start = (rank // rank_slice) * tensor_slice[0] + b_start = (rank % rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + # DistSpec: + # shard_sequence: S0,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [1]} + # DistSpec: + # shard_sequence: R,S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S1S0(device_mesh, rank): + 
tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [1]) + b_target_pair = (b, [0]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_slice = [2, 4] + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S1S0(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + tensor_slice = [tensor_width // 4, tensor_width // 2] # (4, 2) + rank_slice = 4 + f_start = (rank % rank_slice) * tensor_slice[0] + b_start = (rank // rank_slice) * tensor_slice[1] + tensor_to_comm = tensor_to_check[f_start:f_start + tensor_slice[0], + b_start:b_start + tensor_slice[1]].contiguous().cuda() + + dim_partition_dict = {0: [1], 1: [0]} + + # DistSpec: + # shard_sequence: S1,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + # DistSpec: + # shard_sequence: R,S0 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_S01R(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + (f, b) = (0, 1) + f_target_pair = (f, [0, 1]) + b_target_pair = (b, []) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[rank:rank + 1, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_S01R(device_mesh, rank): + tensor_width = 8 + tensor_to_check = 
torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[rank:rank + rank_stride, :].contiguous().cuda() + + dim_partition_dict = {0: [0, 1]} + + # DistSpec: + # shard_sequence: S01, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {0: [0]} + + # DistSpec: + # shard_sequence: S1, R + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=0, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_mix_gather_RS01(device_mesh, rank): + tensor_to_check = torch.arange(64).reshape((8, 8)).cuda() + + (f, b) = (0, 1) + f_target_pair = (f, []) + b_target_pair = (b, [0, 1]) + gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) + tensor_to_comm = tensor_to_check[:, rank:rank + 1].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + source_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + comm_spec = CommSpec(CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axes, + forward_only=True, + mix_gather=True) + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_two_all_gather_RS01(device_mesh, rank): + tensor_width = 8 + tensor_to_check = torch.arange(int(tensor_width * tensor_width)).reshape((tensor_width, tensor_width)).cuda() + + rank_stride = tensor_width // 8 + tensor_to_comm = tensor_to_check[:, rank:rank + rank_stride].contiguous().cuda() + + dim_partition_dict = {1: [0, 1]} + + # DistSpec: + # shard_sequence: R, S01 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:0) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=1) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + dim_partition_dict = {1: [0]} + + # DistSpec: + # shard_sequence: R, S1 + # device_mesh_shape: (2, 4) + sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict) + + # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1) + comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + sharding_spec, + gather_dim=1, + logical_process_axis=0) + + tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm) + + assert tensor_to_comm.equal(tensor_to_check) + + +def check_comm(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, 
world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    physical_mesh_id = torch.arange(0, 8)
+    assert rank == gpc.get_global_rank()
+
+    mesh_shape = (2, 4)
+    # [[0, 1, 2, 3],
+    #  [4, 5, 6, 7]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True, need_flatten=True)
+
+    check_mix_gather_S0S1(device_mesh, rank)
+
+    check_two_all_gather_S0S1(device_mesh, rank)
+
+    check_mix_gather_S1S0(device_mesh, rank)
+
+    check_two_all_gather_S1S0(device_mesh, rank)
+
+    check_mix_gather_S01R(device_mesh, rank)
+
+    check_two_all_gather_S01R(device_mesh, rank)
+
+    check_mix_gather_RS01(device_mesh, rank)
+
+    check_two_all_gather_RS01(device_mesh, rank)
+
+
+@pytest.mark.skip(reason="Skip because the check functions assume 8 GPUs but CI only has 4 GPUs")
+def test_mix_gather():
+    world_size = 8
+    run_func = partial(check_comm, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mix_gather()
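
The column-major (key 0) and row-major (key 1) rank orderings that FlattenDeviceMesh.create_process_numbers_for_logical_mesh builds can be reproduced standalone. The sketch below is illustrative only and assumes nothing beyond torch; the helper name is local to the example, not part of the library.

import operator
from functools import reduce

import torch


def process_numbers_for_logical_mesh(mesh_shape):
    # Mirrors FlattenDeviceMesh.create_process_numbers_for_logical_mesh:
    # key 0 walks the mesh column-major, key 1 walks it row-major.
    num_devices = reduce(operator.mul, mesh_shape, 1)
    ranks = torch.arange(num_devices).reshape(mesh_shape)
    return {
        0: ranks.transpose(1, 0).flatten().tolist(),
        1: ranks.flatten().tolist(),
    }


orders = process_numbers_for_logical_mesh((2, 4))
# Matches the docstring example for a (2, 4) mesh.
assert orders[0] == [0, 4, 1, 5, 2, 6, 3, 7]
assert orders[1] == [0, 1, 2, 3, 4, 5, 6, 7]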
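
After its single global all-gather, _mix_gather only reorders the gathered shards into the leading-group order and concatenates them in two steps. The following single-process sketch (no torch.distributed) replays that logic for the S0,S1 example from the docstring, an 8x8 tensor on a (2, 4) mesh; gather_dim and logical_process_axes are the values mix_gather_simulator returns for S0S1, while the shard layout and variable names are local illustration rather than library code.

import torch

mesh_shape = (2, 4)
full = torch.arange(64).reshape(8, 8)

# Shards each global rank would contribute to the all-gather under S0,S1:
# rank r owns rows (r // 4) * 4 : +4 and columns (r % 4) * 2 : +2.
shards = [full[(r // 4) * 4:(r // 4) * 4 + 4, (r % 4) * 2:(r % 4) * 2 + 2] for r in range(8)]

gather_dim = [1, 0]                  # [b, f] for S0S1
logical_process_axes = (1, 0)        # leading group walks mesh axis 1 (row-major)
# Row-major ordering of global ranks; the identity here, whereas S1,S0 would
# use the column-major list [0, 4, 1, 5, 2, 6, 3, 7].
process_number_list = [0, 1, 2, 3, 4, 5, 6, 7]

# Step 1: reorder the gathered list into the leading-group order.
shards = [shards[i] for i in process_number_list]

# Step 2: concatenate along dim b inside each leading group,
# then along dim f across groups.
cat_slice = [mesh_shape[logical_process_axes[0]], mesh_shape[logical_process_axes[1]]]    # [4, 2]
groups = [
    torch.cat(shards[i * cat_slice[0]:(i + 1) * cat_slice[0]], dim=gather_dim[0])
    for i in range(cat_slice[1])
]
output = torch.cat(groups, dim=gather_dim[1])

assert output.equal(full)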
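
The backward of mix gather (_mix_split) recovers each rank's shard with two torch.narrow calls. This is a local sketch of that index arithmetic for the same S0,S1 case on a (2, 4) mesh; it is not the library function itself, and the expected slice matches the layout used in check_mix_gather_S0S1.

import torch

mesh_shape = (2, 4)
gather_dim = [1, 0]              # [b, f], as mix_gather_simulator returns for S0S1
logical_process_axes = (1, 0)
full = torch.arange(64).reshape(8, 8)


def mix_split_shard(tensor, rank):
    # rank is already the position of the global rank inside the leading-group
    # ordering (the identity for the row-major S0,S1 case).
    rank_slice = [mesh_shape[logical_process_axes[0]], mesh_shape[logical_process_axes[1]]]    # [4, 2]
    tensor_shape = [tensor.shape[gather_dim[0]], tensor.shape[gather_dim[1]]]
    length = [tensor_shape[0] // rank_slice[0], tensor_shape[1] // rank_slice[1]]
    start = [(rank % rank_slice[0]) * length[0], (rank // rank_slice[0]) * length[1]]
    # First narrow along dim b, then along dim f, mirroring the two-step narrow
    # in _mix_split for the case where the two logical process axes differ.
    tmp = torch.narrow(tensor, gather_dim[0], start[0], length[0])
    return torch.narrow(tmp, gather_dim[1], start[1], length[1]).contiguous()


for rank in range(8):
    expected = full[(rank // 4) * 4:(rank // 4) * 4 + 4, (rank % 4) * 2:(rank % 4) * 2 + 2]
    assert mix_split_shard(full, rank).equal(expected)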