diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py new file mode 100644 index 000000000..22bbb1d2f --- /dev/null +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -0,0 +1,556 @@ +import math +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import numpy as np +import torch + +from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.tensor.d_tensor.comm_spec import * +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.sharding_spec import ShardingSpecException +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator + +from .sharding_spec import ShardingSpec +from .utils import get_comm_cost + +__all__ = ['LayoutConverter', 'LayoutConverterOptions', 'set_layout_converting_options'] + + +@dataclass +class LayoutConverterOptions: + """ + LayoutConverterOptions is a dataclass which specifies the preferences for shape consistency. + """ + # TODO: layout converter option is not implemented yet + pass + + +def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: + shape_consistency_manager = LayoutConverter() + global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) + global_layout = Layout(device_mesh=layout.device_mesh, + device_type=layout.device_type, + sharding_spec=global_sharding_spec, + entire_shape=layout.entire_shape) + with torch.no_grad(): + global_tensor = shape_consistency_manager.apply(distributed_tensor, layout, global_layout) + return global_tensor + + +def set_layout_converting_options(options: LayoutConverterOptions): + """ + Configure the shape consistency manager via function call. + """ + manager = LayoutConverter() + manager.options = options + + +class LayoutConverter(metaclass=SingletonMeta): + + def __init__(self): + self._options = None + self._forward_only = False + self.cached_solution = {} + + @property + def options(self): + return self._options + + @options.setter + def options(self, options_: LayoutConverterOptions): + assert isinstance(options_, LayoutConverterOptions) + self._options = options_ + + @property + def forward_only(self): + return self._forward_only + + @forward_only.setter + def forward_only(self, value): + assert isinstance(value, bool) + self._forward_only = value + + def all_gather_transform_layouts(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single all-gather operation. + For the all-gather operation, we just care about the S dimension. + + Argument: + source_layout: the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-gather operation. 
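+
+        Note:
+            Each candidate layout removes exactly one sharding axis: the last mesh axis in the shard
+            list of one sharded tensor dimension is gathered (e.g. S01 -> S0, S0 -> R), as the
+            example below shows.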
+ + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + dim_partition_dict = {0: [0], 1: [1]} + + # [S0,S1,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + + rst_dict = layout_converter.all_gather_transform_layouts(layout) + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [R, S1, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:0, shard_dim:0, logical_process_axis:0) + [S0, R, R]: CommSpec:(comm_pattern:GATHER_FWD_SPLIT_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + source_spec = source_layout.sharding_spec + process_groups_dict = source_layout.device_mesh.process_groups_dict + for target_pair in source_spec.dim_partition_dict.items(): + shard_list = all_gather_simulator(target_pair) + index = target_pair[0] + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if shard_list: + new_dim_partition_dict[index] = shard_list + else: + new_dim_partition_dict.pop(index) + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + gather_dim = index + logical_process_axis = target_pair[1][-1] + comm_spec = CommSpec( + comm_pattern, + process_groups_dict=process_groups_dict, + gather_dim=gather_dim, + # shard_dim will be used during backward + shard_dim=gather_dim, + logical_process_axis=logical_process_axis) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) + new_layout = Layout(device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + device_type=source_layout.device_type, + entire_shape=source_layout.entire_shape) + + valid_spec_dict[new_layout] = comm_spec + except ShardingSpecException: + pass + return valid_spec_dict + + def all_to_all_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single all-to-all operation. + For the all-to-all operation, we just care about the pairs containing S dimension. + + Argument: + source_layout(Layout): the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single all-to-all operation. 
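+
+        Note:
+            Each candidate layout moves one mesh axis between a pair of tensor dimensions, so the
+            total number of sharded mesh axes stays unchanged (see the outputs in the example below).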
+ + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + dim_partition_dict = {0: [0], 1: [1]} + + # [S0,S1,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + rst_dict = layout_converter.all_to_all_transform_layout(layout) + + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [S01, R, R]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:0, logical_process_axis: 1) + [R, S1, S0]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:0, shard_dim:2, logical_process_axis: 0) + [S0, R, S1]: CommSpec:(comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, gather_dim:1, shard_dim:2, logical_process_axis: 1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + process_groups_dict = source_layout.device_mesh.process_groups_dict + source_spec = source_layout.sharding_spec + tensor_dims = source_spec.dims + for f_index in range(tensor_dims - 1): + for b_index in range(f_index + 1, tensor_dims): + # skip (R, R) cases + if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: + continue + else: + if f_index in source_spec.dim_partition_dict: + # skip (S01, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[f_index]) >= 2: + continue + f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) + else: + f_target_pair = (f_index, []) + if b_index in source_spec.dim_partition_dict: + # skip (R, S01) -> (S01, R) is NOT allowed + if len(source_spec.dim_partition_dict[b_index]) >= 2: + continue + b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) + else: + b_target_pair = (b_index, []) + + # skip (S1, S0) -> S10 + if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: + continue + f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) + f_index = f_target_pair[0] + b_index = b_target_pair[0] + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + if len(f_shard_list) < len(f_target_pair[1]): + gather_dim = f_index + shard_dim = b_index + logical_process_axis = f_target_pair[1][-1] + else: + gather_dim = b_index + shard_dim = f_index + logical_process_axis = b_target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + process_groups_dict, + gather_dim=gather_dim, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis) + + new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if f_shard_list: + new_dim_partition_dict[f_index] = f_shard_list + else: + new_dim_partition_dict.pop(f_index) + if b_shard_list: + new_dim_partition_dict[b_index] = b_shard_list + else: + new_dim_partition_dict.pop(b_index) + + # generate new sharding spec + try: + new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict) + new_layout = Layout(device_mesh=source_layout.device_mesh, + sharding_spec=new_sharding_spec, + device_type=source_layout.device_type, + 
entire_shape=source_layout.entire_shape) + valid_spec_dict[new_layout] = comm_spec + except ShardingSpecException: + pass + + return valid_spec_dict + + def shard_transform_layout(self, source_layout: Layout) -> Dict[Layout, CommSpec]: + ''' + Get all valid layouts from source_layout with single shard operation. + For the sharding operation, we just care about legal sharding dimensions. + + Argument: + source_layout(Layout): the layout to be transformed. + + Return: + valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with single shard operation. + + Example: + layout_converter = LayoutConverter() + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + dim_partition_dict = {0: [0]} + + # [S0,R,R] + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + rst_dict = layout_converter.shard_transform_layout(layout) + + for layout, comm_spec in rst_dict.items(): + print(f'{layout.sharding_spec.sharding_sequence}: {comm_spec}') + + Output: + [S01, R, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:0, shard_dim:0, logical_process_axis:1) + [S0, S1, R]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:1, shard_dim:1, logical_process_axis:1) + [S0, R, S1]: CommSpec:(comm_pattern:SPLIT_FWD_GATHER_BWD, gather_dim:2, shard_dim:2, logical_process_axis:1) + ''' + valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + source_spec = source_layout.sharding_spec + process_groups_dict = source_layout.device_mesh.process_groups_dict + + # legal sharding dims means the mesh_id is still available to use. 
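+        # e.g. with mesh_shape (2, 2) and source spec [S0, R, R], only mesh axis 1 is still legal,
+        # so every candidate below shards a tensor dim (or extends an existing shard list) with axis 1.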
+        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
+        for dim, shard_list in source_spec.dim_partition_dict.items():
+            for element in shard_list:
+                legal_sharding_dims.remove(element)
+
+        if len(legal_sharding_dims) == 0:
+            return valid_spec_dict
+
+        tensor_dims = source_spec.dims
+
+        for index in range(tensor_dims):
+            if index not in source_spec.dim_partition_dict:
+                shard_list_list = shard_simulator((index, []), legal_sharding_dims)
+            else:
+                shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims)
+            if not shard_list_list:
+                continue
+            for shard_list in shard_list_list:
+                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+                new_dim_partition_dict[index] = shard_list
+
+                # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec
+                shard_dim = index
+                logical_process_axis = shard_list[-1]
+                comm_spec = CommSpec(comm_pattern,
+                                     process_groups_dict,
+                                     gather_dim=shard_dim,
+                                     shard_dim=shard_dim,
+                                     logical_process_axis=logical_process_axis)
+
+                # generate new sharding spec
+                try:
+                    new_sharding_spec = ShardingSpec(dim_size=source_spec.dims,
+                                                     dim_partition_dict=new_dim_partition_dict)
+                    new_layout = Layout(device_mesh=source_layout.device_mesh,
+                                        sharding_spec=new_sharding_spec,
+                                        device_type=source_layout.device_type,
+                                        entire_shape=source_layout.entire_shape)
+                    valid_spec_dict[new_layout] = comm_spec
+                except ShardingSpecException:
+                    pass
+        return valid_spec_dict
+
+    def get_all_one_step_transform_spec(self, source_layout: Layout) -> Dict[Layout, CommSpec]:
+        '''
+        Get all valid layouts from source_layout with one step transform.
+
+        Note:
+            all-gather will eliminate a sharding dimension, all-to-all will keep the number of sharding
+            dimensions unchanged, and shard will add a sharding dimension. Therefore, the results of the
+            three operations are mutually exclusive, so we can safely merge them into one dict.
+
+        Argument:
+            source_layout(Layout): the layout to be transformed.
+
+        Return:
+            valid_spec_dict(Dict[Layout, CommSpec]): all valid layouts from source_layout with one step transform.
+        '''
+        valid_spec_dict = {}
+        valid_spec_dict.update(self.all_gather_transform_layouts(source_layout))
+        valid_spec_dict.update(self.all_to_all_transform_layout(source_layout))
+        valid_spec_dict.update(self.shard_transform_layout(source_layout))
+        return valid_spec_dict
+
+    def layout_converting(self, source_layout: Layout,
+                          target_layout: Layout) -> Tuple[List[Layout], List[CommSpec]]:
+        '''
+        This method finds a path to transform source_layout into target_layout with
+        a greedy algorithm.
+        The basic idea is:
+        Step 1:
+            Generate all one-step transform layouts from source_layout.
+        Step 2:
+            Pick the 'best' layout following the heuristic function.
+        Step 3:
+            Repeat the steps above until the source layout has been transformed into the target layout.
+
+        Additionally, to avoid repeating the path search at runtime, every solved path is cached
+        at auto-parallel strategy building time, which covers most cases encountered at runtime.
+
+        Args:
+            source_layout(Layout): the layout to be transformed.
+            target_layout(Layout): the layout to be achieved after a series of transforms.
+
+        Return:
+            transform_path(List[Layout]): The transform path from source_layout to target_layout,
+                including both source_layout and target_layout.
+            comm_action_sequence(List[CommSpec]): The communication operations, in order, needed to complete the layout conversion.
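+
+        Note:
+            At each step the candidate whose sharding spec has the smallest spec_diff to the target
+            is kept; the search gives up with a RuntimeError once MAX_TRANSFORM_STEPS (20) is exceeded.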
+ + Example: + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + + # [R,S01,R] + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # [S01,R,R] + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout) + transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + print(transform_path_str) + + output: + [R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R] + ''' + source_spec = source_layout.sharding_spec + target_spec = target_layout.sharding_spec + MAX_TRANSFORM_STEPS = 20 + total_steps = 0 + transform_path = [] + comm_action_sequence = [] + spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + + if spec_pairs in self.cached_solution: + return self.cached_solution[spec_pairs] + + # We do nothing if the sharding spec is all the same. + if source_spec.spec_diff(target_spec) == 0: + self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence) + return ( + transform_path, + comm_action_sequence, + ) + + temp_sharding_layout = source_layout + + transform_path.append(temp_sharding_layout) + # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms + while total_steps <= MAX_TRANSFORM_STEPS: + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_layout) + best_difference_score = math.inf + + for layout, comm_spec in valid_transform_spec_dict.items(): + sharding_spec = layout.sharding_spec + spec_difference = sharding_spec.spec_diff(target_spec) + + if spec_difference == 0: + transform_path.append(layout) + comm_action_sequence.append(comm_spec) + self.cached_solution[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence) + + if spec_difference < best_difference_score: + temp_sharding_layout = layout + temp_comm_spec = comm_spec + best_difference_score = spec_difference + + transform_path.append(temp_sharding_layout) + comm_action_sequence.append(temp_comm_spec) + + total_steps += 1 + + raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") + + def get_total_comm_cost(self, source_layout: Layout, target_layout: Layout) -> Dict[str, float]: + ''' + Get the total communication cost of the layout converting process. 
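+
+        Args:
+            source_layout(Layout): the layout to be transformed.
+            target_layout(Layout): the layout to be achieved after the transforms.
+
+        Return:
+            total_cost(Dict[str, float]): the summed 'forward', 'backward' and 'total' communication
+                costs over every step of the conversion path.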
+ ''' + transform_path, comm_action_sequence = self.layout_converting(source_layout, target_layout) + total_cost = {'forward': 0.0, 'backward': 0.0, 'total': 0.0} + for layout, comm_spec in zip(transform_path, comm_action_sequence): + cost_dict = get_comm_cost(layout, comm_spec, self.forward_only) + for key in total_cost: + total_cost[key] += cost_dict[key] + return total_cost + + def apply(self, tensor: torch.Tensor, source_layout: Layout, target_layout: Layout) -> torch.Tensor: + ''' + Apply target_layout to tensor with source layout, the transform path is generated by the + layout_converting method. + + Argument: + tensor (torch.Tensor): The tensor to be redistributed. + source_layout(Layout): The source layout of the tensor. + target_layout (Layout): The tensor will be redistributed to the target_layout. + + Example: + layout_converter = LayoutConverter() + dim_partition_source = {0: [0]} + dim_partition_target = {1: [0]} + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1, + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + entire_shape = (4, 4, 4) + + # [S0,R,R] + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # [R,S0,R] + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + if rank in (0, 1): + sharded_tensor_0 = torch.zeros(2, 1) + sharded_tensor_1 = torch.ones(2, 1) + # tensor([[0., 1.], + # [0., 1.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + if rank in (2, 3): + sharded_tensor_0 = torch.ones(2, 1) * 2 + sharded_tensor_1 = torch.ones(2, 1) * 3 + # tensor([[2., 3.], + # [2., 3.]]) + tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() + + # converted_tensor: [R, S0, R] + converted_tensor = layout_converter.apply(tensor_to_comm, source_layout, target_layout) + print(converted_tensor) + + Output in rank0 and rank1: + tensor([[0.], + [0.], + [2.], + [2.]]) + + Output in rank2 and rank3: + tensor([[1.], + [1.], + [3.], + [3.]]) + ''' + _, comm_action_sequence = self.layout_converting(source_layout, target_layout) + for comm_spec in comm_action_sequence: + tensor = comm_spec.covert_spec_to_action(tensor) + return tensor diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py new file mode 100644 index 000000000..644bb6306 --- /dev/null +++ b/colossalai/tensor/d_tensor/utils.py @@ -0,0 +1,66 @@ +import operator +from functools import reduce +from typing import Dict + +from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec +from colossalai.tensor.d_tensor.layout import Layout + + +def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]: + ''' + This method is used to compute the communication cost for a given layout and comm_spec. + + For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to + compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is a tiny cost. + + Args: + layout: the layout of the tensor. 
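+            Its sharded shape per device determines the base communication volume (comm_size) used below.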
+ comm_spec: the comm_spec to instruct the communication operation. + forward_only: if it is True, we will just count the forward communication cost. + If it is False, we will count both forward and backward communication cost. + ''' + comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1) + device_mesh = layout.device_mesh + comm_pattern = comm_spec.comm_pattern + logical_process_axis = comm_spec.logical_process_axis + cost_dict = {} + + if comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: + # the comm size for all gather is the size of the gathered tensor + gather_dim = comm_spec.gather_dim + all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] + all_gather_size = device_mesh.mesh_shape[all_gather_axis] + comm_size_for_all_gather = comm_size * all_gather_size + forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) + # give a tiny cost to shard + backward_communication_cost = 100 + + if comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: + forward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis) + # grad should have same shape as input tensor + # all to all operation has same logical process axis as forward. + backward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis) + + if comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: + forward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis) + backward_communication_cost = 0 + + if comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: + forward_communication_cost = 0 + backward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis) + + if comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: + # give a tiny cost to shard + forward_communication_cost = 100 + backward_communication_cost = device_mesh.all_gather_cost(comm_size, logical_process_axis) + + if forward_only: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = 0 + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + else: + cost_dict["forward"] = forward_communication_cost + cost_dict["backward"] = backward_communication_cost + cost_dict["total"] = cost_dict["forward"] + cost_dict["backward"] + + return cost_dict diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py new file mode 100644 index 000000000..70cf8726d --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -0,0 +1,206 @@ +import math +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp + +from colossalai.device.device_mesh import DeviceMesh +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern +from colossalai.tensor.d_tensor.layout import Layout +from colossalai.tensor.d_tensor.layout_converter import LayoutConverter +from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + +entire_shape = torch.Size((64, 32, 16)) +layout_converter = LayoutConverter() +physical_mesh_id = torch.arange(0, 4).reshape(2, 2) +mesh_shape = (2, 2) + + +def check_one_step_transform(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, 
world_size=world_size, host='localhost', port=port, backend='nccl') + # [[0, 1], + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + dim_partition_dict = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (2, 2) + sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec, + entire_shape=entire_shape) + + rst_dict = layout_converter.all_gather_transform_layouts(layout) + + assert '[R, S1, R]' in [ + str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() + ] + assert '[S0, R, R]' in [ + str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys() + ] + + dim_partition_dict_all2all = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all) + layout_all2all = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_all2all, + entire_shape=entire_shape) + + rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all) + + assert '[S01, R, R]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + assert '[R, S1, S0]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + assert '[S0, R, S1]' in [ + str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys() + ] + + dim_partition_shard = {0: [0]} + # DistSpec: + # shard_sequence: S0,R,R + # device_mesh_shape: (4, 4) + sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard) + shard_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_shard, + entire_shape=entire_shape) + + rst_dict_shard = layout_converter.shard_transform_layout(shard_layout) + + assert '[S01, R, R]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + assert '[S0, S1, R]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + assert '[S0, R, S1]' in [ + str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys() + ] + + +def check_layout_converting(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + transform_path, comm_action_sequence = 
layout_converter.layout_converting(source_layout, target_layout) + + # check transform path + transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path]) + assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + + # check comm action sequence + # all-gather(S01) -> S0 + assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD + assert comm_action_sequence[0].gather_dim == 1 + assert comm_action_sequence[0].logical_process_axis == 1 + + # all-to-all(R, S0) -> [S0, R] + assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD + assert comm_action_sequence[1].gather_dim == 1 + assert comm_action_sequence[1].shard_dim == 0 + assert comm_action_sequence[1].logical_process_axis == 0 + + # shard(S0) -> [S01] + assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD + assert comm_action_sequence[2].shard_dim == 0 + assert comm_action_sequence[2].logical_process_axis == 1 + + # checkout chached_spec_pairs_transform_path + assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path + assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence + + comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) + + assert comm_cost['forward'] == comm_cost['backward'] + assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward']) + + +def check_layout_converting_apply(rank, world_size, port): + disable_existing_loggers() + launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source) + source_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_source, + entire_shape=entire_shape) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target) + target_layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=sharding_spec_target, + entire_shape=entire_shape) + + original_tensor = torch.rand(entire_shape).cuda() + + # tensor_to_apply: [R, S01, R] + tensor_to_apply = original_tensor.narrow(1, rank * 8, 8) + + # tensor_to_check: [S01, R, R] + tensor_to_check = original_tensor.narrow(0, rank * 16, 16) + + converted_tensor = layout_converter.apply(tensor_to_apply, source_layout, target_layout) + assert converted_tensor.equal(tensor_to_check) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_layout_converter(): + world_size = 4 + run_func = partial(check_one_step_transform, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + run_func = partial(check_layout_converting, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_layout_converter()
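For illustration only (not part of the patch): the sketch below shows how `to_global`, defined near the top of `layout_converter.py`, reuses the converter to rebuild a fully replicated tensor from a `[S01, R, R]`-sharded shard. `check_to_global` is a hypothetical helper written in the style of the checks above; it reuses the module-level `entire_shape`, `physical_mesh_id` and `mesh_shape`, would be driven through `mp.spawn` with `world_size=4` like the other checks, and asserts only the output shape because each rank draws its own random data.

from colossalai.tensor.d_tensor.layout_converter import to_global


def check_to_global(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # [S01, R, R]: dim 0 is sharded over both mesh axes, so each rank holds a (16, 32, 16) slice
    sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=sharding_spec,
                    entire_shape=entire_shape)

    local_shard = torch.rand(entire_shape).narrow(0, rank * 16, 16).cuda()

    # to_global builds a fully replicated target layout and calls LayoutConverter.apply,
    # so the result should have the entire shape (64, 32, 16) on every rank
    global_tensor = to_global(local_shard, layout)
    assert global_tensor.shape == entire_shape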