[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee 2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@ -188,7 +188,7 @@ class NodeHandler(ABC):
remove_strategy_list = []
for strategy in self.strategies_vector:
shard_axis_list = []
last_axis = len(self.device_mesh.mesh_shape) - 1
last_axis = len(self.device_mesh.shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axes in sharding_spec.dim_partition_dict.items():

View File

@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
device_mesh_is_1d = False
if device_mesh_is_1d:
@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
# only for 1D device mesh
if len(self.device_mesh.mesh_shape) == 1:
if len(self.device_mesh.shape) == 1:
mesh_dim = 0
else:
mesh_dim = self.device_mesh.mesh_shape.index(1)
mesh_dim = self.device_mesh.shape.index(1)
strategy_list.append(self.split_one_batch_dim(mesh_dim))
else:
# for 2D device mesh

View File

@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
# make sure all dims are covered in sharding spec
sharding_len = len(sharding_spec.sharding_sequence)
tensor_num_dim = tensor.dim()
num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
num_devices_in_col = sharding_spec.device_mesh.shape[0]
num_devices_in_row = sharding_spec.device_mesh.shape[1]
assert sharding_len == tensor_num_dim, \
f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

View File

@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if is_distributed_tensor(weight):
if not is_distributed_tensor(weight):
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
continue
# If the states are stored as DTensors, mark isDTensor as true.
if type(state_tensor) == DTensor:
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)

View File

@ -1,5 +1,5 @@
from types import MethodType
from typing import Callable, Optional, Union
from typing import Callable, Dict, Optional, Union
import torch
import torch.distributed as dist
@ -173,7 +173,7 @@ class LazyTensor(torch.Tensor):
self.clean()
return _convert_cls(self, target)
def distribute(self, layout: Layout) -> torch.Tensor:
def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
"""Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.
Args:
@ -537,7 +537,10 @@ class LazyInitContext:
return _apply_to_lazy_module(module, apply_fn, verbose)
@staticmethod
def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
def distribute(module: nn.Module,
device_mesh: DeviceMesh,
sharding_spec_dict: Dict[str, ShardingSpec],
verbose: bool = False) -> nn.Module:
"""Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.
Args:
@ -547,7 +550,7 @@ class LazyInitContext:
"""
def apply_fn(name: str, p: LazyTensor):
p.distribute(layout_dict[name])
p.distribute(device_mesh, sharding_spec_dict[name])
return _apply_to_lazy_module(module, apply_fn, verbose)

View File

@ -16,12 +16,12 @@ def _all_gather(tensor, comm_spec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis])
]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
@ -34,12 +34,12 @@ def _split(tensor, comm_spec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
@ -48,20 +48,17 @@ def _all_to_all(tensor, comm_spec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
length = tensor.shape[comm_spec.shard_dim] // world_size
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
@ -72,9 +69,9 @@ def _all_reduce(tensor, comm_spec, async_op=False):
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
process_group = process_groups[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec):
process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
'''
total_slices = comm_spec.device_mesh.mesh_shape[0]
total_slices = comm_spec.device_mesh.shape[0]
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
leading_group_dim = comm_spec.logical_process_axes[0]
assert len(comm_spec.device_mesh.process_groups_dict) == 1
@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec):
if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
else:
mesh_shape = comm_spec.device_meshes.mesh_shape
mesh_shape = comm_spec.device_meshes.shape
cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
tmp_tensor_shape = list(tensor.shape)
tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec):
# [4, 5, 6, 7]]
# return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
'''
mesh_shape = comm_spec.device_meshes.mesh_shape
mesh_shape = comm_spec.device_meshes.shape
dim = comm_spec.gather_dim
total_slices = comm_spec.device_mesh.mesh_shape[0]
total_slices = comm_spec.device_mesh.shape[0]
# Get global rank
rank = dist.get_rank()
@ -414,7 +411,7 @@ class CommSpec:
self.forward_only = forward_only
if isinstance(self.logical_process_axis, list):
if not mix_gather:
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.device_mesh = self.sharding_spec.device_mesh.flatten()
self.logical_process_axis = 0
else:
self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes

View File

@ -24,12 +24,12 @@ class CommSpec:
'''
Communication spec is used to record the communication action. It converts the communication spec
to real action which will be used in runtime. It contains comm_pattern to determine the
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
to determine the buffer shape, and logical_process_axis
Argument:
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@ -37,7 +37,7 @@ class CommSpec:
def __init__(self,
comm_pattern: CollectiveCommPattern,
process_groups_dict: Dict,
process_group_dict: Dict,
gather_dim: int = None,
shard_dim: int = None,
logical_process_axis: int = None):
@ -45,7 +45,7 @@ class CommSpec:
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
self.process_groups_dict = process_groups_dict
self.process_group_dict = process_group_dict
def __repr__(self):
res_list = ["CommSpec:("]
@ -92,12 +92,9 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all gather operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
# without this contiguous operation, the all gather may get some unexpected results.
tensor = tensor.contiguous()
dist.all_gather(tensor_list, tensor, group=process_group)
@ -109,12 +106,10 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement shard operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, _ in process_groups_list:
if dist.get_rank() in rank_list:
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
start = length * dist.get_rank(process_group)
output = torch.narrow(tensor, dim, start, length).contiguous()
return output
@ -123,20 +118,15 @@ def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
'''
Implement all to all operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
world_size = dist.get_world_size(process_group)
new_shape = list(tensor.shape)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
length = tensor.shape[comm_spec.shard_dim] // world_size
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
@ -147,9 +137,7 @@ def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = Fals
'''
Implement all reduce operation on device mesh based on information provided by comm_spec.
'''
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
def forward(ctx, input_, comm_spec):
output = _all_to_all(input_, comm_spec)
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
process_groups_dict=comm_spec.process_groups_dict,
process_group_dict=comm_spec.process_group_dict,
gather_dim=comm_spec.shard_dim,
shard_dim=comm_spec.gather_dim,
logical_process_axis=comm_spec.logical_process_axis)

View File

@ -14,24 +14,21 @@ class Layout:
Attributes:
device_mesh: the device mesh to store the tensor distributed.
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
sharding_spec: the sharding specification to describe how the tensor is sharded.
entire_shape: the entire shape of the global tensor.
global_shape: the entire shape of the global tensor.
"""
def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
entire_shape: torch.Size):
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
self.device_mesh = device_mesh
self.device_type = device_type
self.sharding_spec = sharding_spec
self.entire_shape = entire_shape
self.global_shape = global_shape
self._sanity_check()
def __hash__(self) -> int:
return hash(f'{self.sharding_spec}')
def get_sharded_shape_per_device(self):
sharded_shape = list(self.entire_shape)
sharded_shape = list(self.global_shape)
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
@ -56,7 +53,7 @@ class Layout:
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in sharding_spec.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
tensor_dim_size = self.global_shape[dim]
num_devices = 1
for element in shard_list:

View File

@ -3,10 +3,8 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Tuple
import numpy as np
import torch
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
class LayoutConverter(metaclass=SingletonMeta):
"""
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
"""
def __init__(self):
self._options = None
@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
for layout, comm_spec in rst_dict.items():
@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
for target_pair in source_spec.dim_partition_dict.items():
shard_list = all_gather_simulator(target_pair)
index = target_pair[0]
@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
logical_process_axis = target_pair[1][-1]
comm_spec = CommSpec(
comm_pattern,
process_groups_dict=process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
# shard_dim will be used during backward
shard_dim=gather_dim,
@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0], 1: [1]}
# [S0,S1,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.all_to_all_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
'''
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
source_spec = source_layout.sharding_spec
tensor_dims = source_spec.dims
for f_index in range(tensor_dims - 1):
@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = f_index
logical_process_axis = b_target_pair[1][-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=gather_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_dict = {0: [0]}
# [S0,R,R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
global_shape=global_shape)
rst_dict = layout_converter.shard_transform_layout(layout)
for layout, comm_spec in rst_dict.items():
@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
valid_spec_dict = {}
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
source_spec = source_layout.sharding_spec
process_groups_dict = source_layout.device_mesh.process_groups_dict
# the key of the dict is the axis
# the value is the process group
current_rank = source_layout.device_mesh._global_rank_of_current_process
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
shard_dim = index
logical_process_axis = shard_list[-1]
comm_spec = CommSpec(comm_pattern,
process_groups_dict,
process_group_dict=process_group_dict,
gather_dim=shard_dim,
shard_dim=shard_dim,
logical_process_axis=logical_process_axis)
@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
dim_partition_dict=new_dim_partition_dict)
new_layout = Layout(device_mesh=source_layout.device_mesh,
sharding_spec=new_sharding_spec,
device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape)
global_shape=source_layout.global_shape)
valid_spec_dict[new_layout] = comm_spec
except LayoutException:
pass
@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
dim_partition_source = {1: [0, 1]}
dim_partition_target = {0: [0, 1]}
@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
# [R,S01,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)
# [S01,R,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = (4, 4, 4)
global_shape = (4, 4, 4)
# [S0,R,R]
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
global_shape=global_shape)
# [R,S0,R]
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
global_shape=global_shape)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)

View File

@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
# legal sharding dims means the mesh_id is still available to use.
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
for dim, shard_list in source_spec.dim_partition_dict.items():
for element in shard_list:
legal_sharding_dims.remove(element)
@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
"""
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
alloc_numel += output_numel
if discard_input:
@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
# generate a new tensor
input_shape = compute_shape(comm_spec.sharding_spec)
input_numel = np.prod(input_shape)
output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
alloc_numel += output_numel
peak_numel = max(peak_numel, alloc_numel)
if discard_input:

View File

@ -195,7 +195,7 @@ class ShardingSpec:
def __repr__(self):
res_list = ["DistSpec:"]
res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}")
res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
return ' '.join(res_list)
def _sanity_check(self):
@ -222,7 +222,7 @@ class ShardingSpec:
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
num_devices *= self.device_mesh.shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
@ -288,7 +288,7 @@ class ShardingSpec:
sharded_shape = list(self.entire_shape)
for dim, shard_list in self.dim_partition_dict.items():
mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
shard_partitions = reduce(operator.mul, mesh_list, 1)
assert sharded_shape[
dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'

View File

@ -1 +0,0 @@
from colossalai.tensor.d_tensor.api import to_distributed_tensor

View File

@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory):
if __name__ == "__main__":
run_test(
rank=0,
data=get_data(LATENTS_SHAPE),
max_memory=None,
model=UNet2DModel,
print_code=False,
print_mem=True,
print_est_mem=False,
print_progress=False,
)
test_evoformer_block()

View File

@ -22,7 +22,7 @@ from tests.kit.model_zoo import model_zoo
@parameterize('use_safetensors', [False, True])
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
from transformers import BertForSequenceClassification
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
bert_model = model_fn()
with shared_tempdir() as tempdir:
@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
@parameterize('shard', [True, False])
@parameterize('model_name', ['transformers_gpt'])
def exam_state_dict(placement_policy, shard: bool, model_name: str):
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = lambda x: x.mean()
plugin = GeminiPlugin(placement_policy=placement_policy)
booster = Booster(plugin=plugin)

View File

@ -8,18 +8,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
def test_device_mesh():
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
assert device_mesh.convert_map[5] == [1, 1]
assert device_mesh.convert_map[11] == [2, 3]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
def check_1d_device_mesh():

View File

@ -20,15 +20,11 @@ def check_layer(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
logical_process_groups = device_mesh.process_groups_dict
for mesh_dim, pgs in logical_pg_dict.items():
for index, pg in enumerate(pgs):
if rank in pg:
for axis in range(len(mesh_shape)):
tensor = torch.ones(4).cuda()
group = logical_process_groups[mesh_dim][index][1]
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
pg = device_mesh.get_process_group(axis=axis)
dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
assert tensor.equal(tensor_to_check)
gpc.destroy()

View File

@ -1,3 +1,5 @@
from typing import List
import torch
from numpy import isin
from torch.fx import GraphModule
@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
from colossalai._analyzer.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):
def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
# must turn on eval mode to ensure the output is consistent
model.eval()
inputs = data_gen()
if ignore_data is not None:
# drop the ignore_data key
inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
meta_args = {k: v.to('meta') for k, v in inputs.items()}
gm = symbolic_trace(model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
# run forward
inputs = data_gen()
non_fx_out = model(**inputs)
fx_out = gm(**inputs)

View File

@ -15,7 +15,7 @@ SEQ_LENGTH = 16
def test_albert():
sub_registry = model_zoo.get_sub_registry('transformers_albert')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)

View File

@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo
def test_bert():
sub_registry = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])
if __name__ == '__main__':

View File

@ -47,7 +47,7 @@ def test_diffusers():
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()

View File

@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def test_gpt():
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
# TODO: support the following models
@ -21,7 +21,7 @@ def test_gpt():
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
continue
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
if __name__ == '__main__':

View File

@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)

View File

@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
def test_t5():
sub_registry = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
if name == "transformers_t5_for_conditional_generation":
# cannot trace for loss function yet
# so we use a data gen which does not produce labels
data_gen_fn = sub_registry.get('transformers_t5')[1]
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
if __name__ == '__main__':

View File

@ -56,7 +56,7 @@ def test_timm_models():
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}

View File

@ -16,7 +16,7 @@ def test_torchaudio_models():
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
model = model_fn()
trace_and_compare(model,
data_gen_fn,

View File

@ -53,7 +53,7 @@ def test_torchrec_deepfm_models():
deepfm_models = model_zoo.get_sub_registry('deepfm')
torch.backends.cudnn.deterministic = True
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}

View File

@ -53,7 +53,7 @@ def test_torchrec_dlrm_models():
torch.backends.cudnn.deterministic = True
dlrm_models = model_zoo.get_sub_registry('dlrm')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items():
data = data_gen_fn()
# dlrm_interactionarch is not supported

View File

@ -10,7 +10,7 @@ def test_torchvision_models():
torch.backends.cudnn.deterministic = True
tv_sub_registry = model_zoo.get_sub_registry('torchvision')
for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items():
data = data_gen_fn()
if model_attribute is not None and model_attribute.has_stochastic_depth_prob:

View File

@ -6,6 +6,7 @@ import numpy as np
import torch
from packaging import version
from colossalai.device.device_mesh import DeviceMesh
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
from colossalai.tensor.d_tensor import to_global
from colossalai.tensor.d_tensor.layout import Layout
@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
print(f'{model.__class__.__name__} pass')
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
sharding_spec_dict: dict) -> None:
state = model.state_dict()
distributed_state = distributed_model.state_dict()

View File

@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
return dim
def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
shard_dim = find_shard_dim(original_tensor.shape)
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
return layout
return target_sharding_spec
def _get_current_name(prefix: str, name: str) -> str:
return f'{prefix}.{name}'.lstrip('.')
def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
layout_dict = {}
def generate_sharding_spec_dict(model: nn.Module) -> dict:
sharding_spec_dict = {}
@torch.no_grad()
def generate_recursively(module: nn.Module, prefix: str = ''):
@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
# initialize tensors directly attached to the current module
for name, param in module.named_parameters(recurse=False):
if isinstance(param, LazyTensor):
layout = make_layout(device_mesh, param)
layout_dict[_get_current_name(prefix, name)] = layout
sharding_spec = make_sharding_spec(param)
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
for name, buf in module.named_buffers(recurse=False):
if isinstance(buf, LazyTensor):
layout = make_layout(device_mesh, buf)
layout_dict[_get_current_name(prefix, name)] = layout
sharding_spec = make_sharding_spec(buf)
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
generate_recursively(model)
return layout_dict
return sharding_spec_dict
@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
ctx = LazyInitContext()
with ctx:
deferred_model = model_fn()
layout_dict = generate_layout_dict(deferred_model, device_mesh)
ctx.distribute(deferred_model, layout_dict, verbose=True)
assert_dist_model_equal(model, deferred_model, layout_dict)
sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
def run_dist(rank, world_size, port) -> None:

View File

@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
continue
check_lazy_init(entry, verbose=True)

View File

@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
assert tensor_to_comm.equal(tensor_to_check)
def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
# tensor to comm
tensor_to_comm = torch.ones(2, 2).cuda() * rank
# reduce through logical process axis 0 at flatten device mesh
# tensor to check
# tensor([[6., 6.],
# [6., 6.]])
tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
assert tensor_to_comm.equal(tensor_to_check)
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@ -150,24 +133,22 @@ def check_comm(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
process_groups_dict = device_mesh.process_groups_dict
process_group_dict = device_mesh._process_group_dict[rank]
# test all gather
check_all_gather(process_groups_dict, rank)
check_all_gather(process_group_dict, rank)
# test shard
check_shard(process_groups_dict, rank)
check_shard(process_group_dict, rank)
# test all to all
check_all_to_all(process_groups_dict, rank)
check_all_to_all(process_group_dict, rank)
# test all reduce
check_all_reduce_fwd(process_groups_dict, rank)
check_all_reduce_bwd(process_groups_dict, rank)
check_all_reduce_fwd(process_group_dict, rank)
check_all_reduce_bwd(process_group_dict, rank)
flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
# test all reduce in 1D flatten device mesh
check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
gpc.destroy()

View File

@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')
dtensor_from_local = distribute_tensor(original_tensor, new_layout)
dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
if rank == 0:
assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))

View File

@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
entire_shape = torch.Size((64, 32, 16))
global_shape = torch.Size((64, 32, 16))
layout_converter = LayoutConverter()
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec,
entire_shape=entire_shape)
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
layout_all2all = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_all2all,
entire_shape=entire_shape)
layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
shard_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_shard,
entire_shape=entire_shape)
shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
# checkout cached_spec_pairs_transform_path
# checkout chached_spec_pairs_transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
source_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_source,
entire_shape=entire_shape)
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
target_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=sharding_spec_target,
entire_shape=entire_shape)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
original_tensor = torch.rand(entire_shape).cuda()
original_tensor = torch.rand(global_shape).cuda()
# tensor_to_apply: [R, S01, R]
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)

View File

@ -1,9 +1,10 @@
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],

View File

@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
row_id = rank // 2

View File

@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],