[DTensor] implement layout converter (#3055)

* [DTensor] refactor LayoutConverter for DTensor

* polish code

* polish docstring
YuliangLiu0306
2023-03-10 09:53:52 +08:00
committed by GitHub
parent 89aa7926ac
commit 8e4e8601b7
3 changed files with 828 additions and 0 deletions


@@ -0,0 +1,206 @@
import math
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port

entire_shape = torch.Size((64, 32, 16))
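# the module-level LayoutConverter is reused by every check below and caches
# conversion solutions (see the cached_solution assertions in check_layout_converting)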
layout_converter = LayoutConverter()
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
mesh_shape = (2, 2)


def check_one_step_transform(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# [[0, 1],
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
dim_partition_dict = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
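# one-step all-gather candidates from [S0, S1, R]: gathering either sharded dim
# gives [R, S1, R] or [S0, R, R] (asserted below)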
rst_dict = layout_converter.all_gather_transform_layouts(layout)
assert '[R, S1, R]' in [
str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
]
assert '[S0, R, R]' in [
str(all_gather_layout.sharding_spec.sharding_sequence) for all_gather_layout in rst_dict.keys()
]
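
# one-step all-to-all candidates from [S0, S1, R]: each moves a shard to another
# tensor dim, giving [S01, R, R], [R, S1, S0] or [S0, R, S1] (asserted below)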
dim_partition_dict_all2all = {0: [0], 1: [1]}
# DistSpec:
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
layout_all2all = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_all2all,
entire_shape=entire_shape)
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
assert '[S01, R, R]' in [
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
]
assert '[R, S1, S0]' in [
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
]
assert '[S0, R, S1]' in [
str(all2all_layout.sharding_spec.sharding_sequence) for all2all_layout in rst_dict_all2all.keys()
]
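
# one-step shard candidates from [S0, R, R]: splitting along an unused mesh axis
# gives [S01, R, R], [S0, S1, R] or [S0, R, S1] (asserted below)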
dim_partition_shard = {0: [0]}
# DistSpec:
# shard_sequence: S0,R,R
# device_mesh_shape: (2, 2)
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
shard_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_shard,
entire_shape=entire_shape)
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
assert '[S01, R, R]' in [
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
]
assert '[S0, S1, R]' in [
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
]
assert '[S0, R, S1]' in [
str(shard_layout.sharding_spec.sharding_sequence) for shard_layout in rst_dict_shard.keys()
]


def check_layout_converting(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# DistSpec:
# shard_sequence: R,S01,R
# device_mesh_shape: (2, 2)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (2, 2)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
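# layout_converting returns the sequence of layouts from source to target and the
# communication actions that realize each step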
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
# check transform path
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'
# check comm action sequence
# all-gather(S01) -> S0
assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
assert comm_action_sequence[0].gather_dim == 1
assert comm_action_sequence[0].logical_process_axis == 1
# all-to-all(R, S0) -> [S0, R]
assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
assert comm_action_sequence[1].gather_dim == 1
assert comm_action_sequence[1].shard_dim == 0
assert comm_action_sequence[1].logical_process_axis == 0
# shard(S0) -> [S01]
assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
# check that the (source, target) spec pair and its transform path were cached
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
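# get_total_comm_cost reports 'forward', 'backward' and 'total' cost entries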
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
assert comm_cost['forward'] == comm_cost['backward']
assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])


def check_layout_converting_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# DistSpec:
# shard_sequence: R,S01,R
# device_mesh_shape: (2, 2)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (2, 2)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
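# entire_shape is (64, 32, 16) across 4 devices: [R, S01, R] keeps 32 / 4 = 8 slices
# of dim 1 per rank and [S01, R, R] keeps 64 / 4 = 16 slices of dim 0 per rank,
# hence the narrow() offsets of rank * 8 and rank * 16 below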
original_tensor = torch.rand(entire_shape).cuda()
# tensor_to_apply: [R, S01, R]
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
# tensor_to_check: [S01, R, R]
tensor_to_check = original_tensor.narrow(0, rank * 16, 16)
converted_tensor = layout_converter.apply(tensor_to_apply, source_layout, target_layout)
assert converted_tensor.equal(tensor_to_check)


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_layout_converter():
world_size = 4
run_func = partial(check_one_step_transform, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
run_func = partial(check_layout_converting, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
run_func = partial(check_layout_converting_apply, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
test_layout_converter()