[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis])
]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
new_shape = torch.Size(new_shape)
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // world_size
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
def _all_reduce(tensor, comm_spec, async_op=False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
return tensor
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
return tensor
def _mix_gather(tensor, comm_spec):
@@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec):
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
'''
total_slices = comm_spec.device_mesh.mesh_shape[0]
total_slices = comm_spec.device_mesh.shape[0]
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
leading_group_dim = comm_spec.logical_process_axes[0]
assert len(comm_spec.device_mesh.process_groups_dict) == 1
@@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec):
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
else:
mesh_shape = comm_spec.device_meshes.mesh_shape
mesh_shape = comm_spec.device_meshes.shape
cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
tmp_tensor_shape = list(tensor.shape)
tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
@@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec):
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
mesh_shape = comm_spec.device_meshes.mesh_shape
mesh_shape = comm_spec.device_meshes.shape
dim = comm_spec.gather_dim
total_slices = comm_spec.device_mesh.mesh_shape[0]
total_slices = comm_spec.device_mesh.shape[0]
# Get global rank
rank = dist.get_rank()
@@ -414,7 +411,7 @@ class CommSpec:
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
if not mix_gather:
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.device_mesh = self.sharding_spec.device_mesh.flatten()
self.logical_process_axis = 0
else:
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes

View File

@@ -24,12 +24,12 @@ class CommSpec:
'''
Communication spec is used to record the communication action. It converts the communication spec
to real action which will be used in runtime. It contains comm_pattern to determine the
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@@ -37,7 +37,7 @@ class CommSpec:
def __init__(self,
comm_pattern: CollectiveCommPattern,
process_groups_dict: Dict,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
@@ -45,7 +45,7 @@ class CommSpec:
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.process_groups_dict = process_groups_dict
self.process_group_dict = process_group_dict
def __repr__(self):
res_list = ["CommSpec:("]
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
return output
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
new_shape = torch.Size(new_shape)
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // world_size
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
return output
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
return tensor
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
return tensor
class _ReduceGrad(torch.autograd.Function):
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
process_groups_dict=comm_spec.process_groups_dict,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)

View File

@@ -14,24 +14,21 @@ class Layout:
Attributes:
device_mesh: the device mesh to store the tensor distributed.
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
sharding_spec: the sharding specification to describe how the tensor is sharded.
entire_shape: the entire shape of the global tensor.
global_shape: the entire shape of the global tensor.
"""
def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
entire_shape: torch.Size):
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
self.device_mesh = device_mesh
self.device_type = device_type
self.sharding_spec = sharding_spec
self.entire_shape = entire_shape
self.global_shape = global_shape
self._sanity_check()
def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -56,7 +53,7 @@ class Layout:
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
tensor_dim_size = self.global_shape[dim]
num_devices = 1
for element in shard_list:

View File

@@ -3,10 +3,8 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
@@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
class LayoutConverter(metaclass=SingletonMeta):
"""
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
"""
def __init__(self):
self._options = None
@@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
for layout, comm_spec in rst_dict.items():
@@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
for target_pair in source_spec.dim_partition_dict.items():
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]
@@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(
comm_pattern,
process_groups_dict=process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
shard_dim=gather_dim,
@@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
@@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.all_to_all_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
source_spec = source_layout.sharding_spec
tensor_dims = source_spec.dims
for f_index in range(tensor_dims - 1):
@@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0]}
# [S0,R,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.shard_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
@@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
@@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [R,S01,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)
# [S01,R,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
# [S0,R,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)
# [R,S0,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)

View File

@@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
@@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
alloc_numel += output_numel
if discard_input:
@@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
alloc_numel += output_numel
peak_numel = max(peak_numel, alloc_numel)
if discard_input:

View File

@@ -195,7 +195,7 @@ class ShardingSpec:
def __repr__(self):
res_list = ["DistSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}")
res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
return ' '.join(res_list)
def _sanity_check(self):
@@ -222,7 +222,7 @@ class ShardingSpec:
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
@@ -288,7 +288,7 @@ class ShardingSpec:
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'