[test] fixed tests that failed due to dtensor change (#4082)

* [test] fixed tests that failed due to dtensor change

* polish code
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions
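
For reference, a minimal before/after sketch of the Layout construction this commit migrates callers to. The import paths for DeviceMesh and ShardingSpec and the concrete mesh/shape values are assumptions taken from the docstring examples in the diff below, not from the commit itself, and the snippet assumes a distributed launch with 4 processes.

import torch
from colossalai.device.device_mesh import DeviceMesh               # assumed import path
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec  # assumed import path

# assumed: a 2x2 logical mesh, i.e. [[0, 1], [2, 3]]
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# [S0, S1, R] sharding over a (4, 4, 4) global tensor
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})

# Old constructor (before this commit):
#   Layout(device_mesh=device_mesh, device_type=torch.device('cuda'),
#          sharding_spec=sharding_spec, entire_shape=(4, 4, 4))
# New constructor: `device_type` is dropped and `entire_shape` is renamed to `global_shape`.
layout = Layout(device_mesh=device_mesh,
                sharding_spec=sharding_spec,
                global_shape=(4, 4, 4))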


@@ -3,10 +3,8 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
@@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
class LayoutConverter(metaclass=SingletonMeta):
"""
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
"""
def __init__(self):
self._options = None
@@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-entire_shape = (4, 4, 4)
+global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
-entire_shape=entire_shape)
+global_shape=global_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
for layout, comm_spec in rst_dict.items():
@@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
-process_groups_dict = source_layout.device_mesh.process_groups_dict
+# the key of the dict is the axis
+# the value is the process group
+current_rank = source_layout.device_mesh._global_rank_of_current_process
+process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
for target_pair in source_spec.dim_partition_dict.items():
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]
@@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(
comm_pattern,
-process_groups_dict=process_groups_dict,
+process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
shard_dim=gather_dim,
@@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
-device_type=source_layout.device_type,
-entire_shape=source_layout.entire_shape)
+global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
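
For reference, a minimal sketch of the new CommSpec wiring used in the hunks above: instead of passing the mesh-wide process_groups_dict, each rank now looks up its own axis-to-process-group dict on the DeviceMesh. The helper function name below is hypothetical; the attribute and argument names are taken from the diff.

from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec

def build_all_gather_comm_spec(source_layout, gather_dim, logical_process_axis):
    # hypothetical helper mirroring the all-gather branch above
    current_rank = source_layout.device_mesh._global_rank_of_current_process
    # per-rank dict: logical mesh axis -> process group
    process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
    return CommSpec(
        CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
        process_group_dict=process_group_dict,
        gather_dim=gather_dim,
        # shard_dim is used during backward
        shard_dim=gather_dim,
        logical_process_axis=logical_process_axis)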
@@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-entire_shape = (4, 4, 4)
+global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
-entire_shape=entire_shape)
+global_shape=global_shape)
rst_dict = layout_converter.all_to_all_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
-process_groups_dict = source_layout.device_mesh.process_groups_dict
+# the key of the dict is the axis
+# the value is the process group
+current_rank = source_layout.device_mesh._global_rank_of_current_process
+process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
source_spec = source_layout.sharding_spec
tensor_dims = source_spec.dims
for f_index in range(tensor_dims - 1):
@@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
-process_groups_dict,
+process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
-device_type=source_layout.device_type,
-entire_shape=source_layout.entire_shape)
+global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-entire_shape = (4, 4, 4)
+global_shape = (4, 4, 4)
dim_partition_dict = {0: [0]}
# [S0,R,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
-entire_shape=entire_shape)
+global_shape=global_shape)
rst_dict = layout_converter.shard_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
-process_groups_dict = source_layout.device_mesh.process_groups_dict
+# the key of the dict is the axis
+# the value is the process group
+current_rank = source_layout.device_mesh._global_rank_of_current_process
+process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
@@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
-process_groups_dict,
+process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
-device_type=source_layout.device_type,
-entire_shape=source_layout.entire_shape)
+global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-entire_shape = (4, 4, 4)
+global_shape = (4, 4, 4)
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
@@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [R,S01,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
-entire_shape=entire_shape)
+global_shape=global_shape)
# [S01,R,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
-entire_shape=entire_shape)
+global_shape=global_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-entire_shape = (4, 4, 4)
+global_shape = (4, 4, 4)
# [S0,R,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
-entire_shape=entire_shape)
+global_shape=global_shape)
# [R,S0,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
-entire_shape=entire_shape)
+global_shape=global_shape)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)