[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
Frank Lee 2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -188,7 +188,7 @@ class NodeHandler(ABC):
         remove_strategy_list = []
         for strategy in self.strategies_vector:
             shard_axis_list = []
-            last_axis = len(self.device_mesh.mesh_shape) - 1
+            last_axis = len(self.device_mesh.shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                     for dim, shard_axes in sharding_spec.dim_partition_dict.items():

View File

@@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+        if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
             device_mesh_is_1d = False

         if device_mesh_is_1d:
@@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             # Sb = Sb x Sb
             # can be None as it is only for 1D device mesh
             # only for 1D device mesh
-            if len(self.device_mesh.mesh_shape) == 1:
+            if len(self.device_mesh.shape) == 1:
                 mesh_dim = 0
             else:
-                mesh_dim = self.device_mesh.mesh_shape.index(1)
+                mesh_dim = self.device_mesh.shape.index(1)
             strategy_list.append(self.split_one_batch_dim(mesh_dim))
         else:
             # for 2D device mesh

View File

@@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure all dims are covered in sharding spec
     sharding_len = len(sharding_spec.sharding_sequence)
     tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    num_devices_in_col = sharding_spec.device_mesh.shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.shape[1]
     assert sharding_len == tensor_num_dim, \
         f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

View File

@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if is_distributed_tensor(weight):
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
@@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
                 continue

            # If the states are stored as DTensors, mark isDTensor as true.
-            if type(state_tensor) == DTensor:
+            if is_distributed_tensor(state_tensor):
                isDTensor = True
            state_size += calculate_tensor_size(state_tensor)
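
For reference, a minimal sketch of the duck-typed check these two hunks converge on: both the model and optimizer sharders now call is_distributed_tensor() instead of comparing against the DTensor class. The import path and the size-accounting helper below are assumptions for illustration, not the exact checkpoint_io code.

    import torch

    # helper name taken from this diff; the import path is an assumption
    from colossalai.tensor.d_tensor import is_distributed_tensor


    def local_state_size(state_dict: dict) -> int:
        """Illustrative only: sum the bytes of non-distributed tensors, skipping
        DTensors the same way the updated sharders do."""
        total = 0
        for _, tensor in state_dict.items():
            if not isinstance(tensor, torch.Tensor):
                continue
            if is_distributed_tensor(tensor):
                # distributed tensors are tracked separately by the sharder
                continue
            total += tensor.numel() * tensor.element_size()
        return total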

View File

@@ -1,5 +1,5 @@
 from types import MethodType
-from typing import Callable, Optional, Union
+from typing import Callable, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -173,7 +173,7 @@ class LazyTensor(torch.Tensor):
         self.clean()
         return _convert_cls(self, target)

-    def distribute(self, layout: Layout) -> torch.Tensor:
+    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
         """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

         Args:
@@ -537,7 +537,10 @@ class LazyInitContext:
         return _apply_to_lazy_module(module, apply_fn, verbose)

     @staticmethod
-    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+    def distribute(module: nn.Module,
+                   device_mesh: DeviceMesh,
+                   sharding_spec_dict: Dict[str, ShardingSpec],
+                   verbose: bool = False) -> nn.Module:
         """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
@@ -547,7 +550,7 @@ class LazyInitContext:
         """

         def apply_fn(name: str, p: LazyTensor):
-            p.distribute(layout_dict[name])
+            p.distribute(device_mesh, sharding_spec_dict[name])

         return _apply_to_lazy_module(module, apply_fn, verbose)
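
A minimal usage sketch of the new distribute signature, mirroring the pattern the updated lazy-init test uses later in this diff; the toy module, the ShardingSpec import path, and the launched 2x2 mesh are assumptions.

    import torch
    import torch.nn as nn

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.lazy.lazy_init import LazyInitContext
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec    # assumed import path

    # assumes torch.distributed is already initialized with 4 ranks (e.g. via colossalai.launch)
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

    ctx = LazyInitContext()
    with ctx:
        model = nn.Linear(8, 8)    # placeholder lazily-initialized module

    # one ShardingSpec per parameter name: shard dim 0 of every parameter along mesh axis 0
    sharding_spec_dict = {
        name: ShardingSpec(dim_size=param.dim(), dim_partition_dict={0: [0]})
        for name, param in model.named_parameters()
    }

    # the device mesh and per-parameter sharding specs replace the old layout_dict argument
    ctx.distribute(model, device_mesh, sharding_spec_dict, verbose=True)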

View File

@@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    tensor_list = [
+        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+        for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis])
+    ]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _split(tensor, comm_spec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output


 def _all_to_all(tensor, comm_spec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor


 def _mix_gather(tensor, comm_spec):
@@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec):
     process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
     tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
     '''
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]
     tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
     leading_group_dim = comm_spec.logical_process_axes[0]
     assert len(comm_spec.device_mesh.process_groups_dict) == 1
@@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec):
     if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
         output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
     else:
-        mesh_shape = comm_spec.device_meshes.mesh_shape
+        mesh_shape = comm_spec.device_meshes.shape
         cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
         tmp_tensor_shape = list(tensor.shape)
         tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
@@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec):
     # [4, 5, 6, 7]]
     # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
     '''
-    mesh_shape = comm_spec.device_meshes.mesh_shape
+    mesh_shape = comm_spec.device_meshes.shape
     dim = comm_spec.gather_dim
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]

     # Get global rank
     rank = dist.get_rank()
@@ -414,7 +411,7 @@ class CommSpec:
         self.forward_only = forward_only
         if isinstance(self.logical_process_axis, list):
             if not mix_gather:
-                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                self.device_mesh = self.sharding_spec.device_mesh.flatten()
                 self.logical_process_axis = 0
             else:
                 self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
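
For reference, a small standalone sketch of the lookup pattern these helpers switch to: one process group per logical mesh axis, taken straight from the DeviceMesh instead of scanning rank lists. It assumes a launched 4-rank environment with CUDA available.

    import torch
    import torch.distributed as dist

    from colossalai.device.device_mesh import DeviceMesh

    # assumes torch.distributed is already initialized with 4 ranks arranged as a 2x2 mesh
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

    axis = 0    # logical mesh axis to communicate along
    process_groups = device_mesh.get_process_group_for_all_axes()    # {axis: ProcessGroup}
    pg = process_groups[axis]

    # rank and world size now come from the group itself rather than a rank list
    tensor = torch.ones(2, 2, device='cuda').contiguous()
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size(pg))]
    dist.all_gather(gathered, tensor, group=pg)
    output = torch.cat(gathered, dim=0)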

View File

@@ -24,12 +24,12 @@ class CommSpec:
     '''
     Communication spec is used to record the communication action. It converts the communication spec
     to real action which will be used in runtime. It contains comm_pattern to determine the
-    communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
+    communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
     to determine the buffer shape, and logical_process_axis

     Argument:
         comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
-        process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
+        process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
         logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@@ -37,7 +37,7 @@ class CommSpec:

     def __init__(self,
                  comm_pattern: CollectiveCommPattern,
-                 process_groups_dict: Dict,
+                 process_group_dict: Dict,
                  gather_dim: int = None,
                  shard_dim: int = None,
                  logical_process_axis: int = None):
@@ -45,7 +45,7 @@ class CommSpec:
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
         self.logical_process_axis = logical_process_axis
-        self.process_groups_dict = process_groups_dict
+        self.process_group_dict = process_group_dict

     def __repr__(self):
         res_list = ["CommSpec:("]
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output


 def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor


 class _ReduceGrad(torch.autograd.Function):
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
     def forward(ctx, input_, comm_spec):
         output = _all_to_all(input_, comm_spec)
         comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
-                                          process_groups_dict=comm_spec.process_groups_dict,
+                                          process_group_dict=comm_spec.process_group_dict,
                                           gather_dim=comm_spec.shard_dim,
                                           shard_dim=comm_spec.gather_dim,
                                           logical_process_axis=comm_spec.logical_process_axis)
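
A hedged sketch of constructing the renamed CommSpec: the process_group_dict keyword now takes a per-axis mapping of process groups for the current rank. The mesh setup, the specific dims and axis, and the use of get_process_group_for_all_axes() as the source of that mapping are illustrative assumptions.

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec

    # assumes torch.distributed is already initialized with 4 ranks on a 2x2 mesh
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

    # assumed to return the same axis -> ProcessGroup mapping that CommSpec expects
    process_group_dict = device_mesh.get_process_group_for_all_axes()

    comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                         process_group_dict=process_group_dict,
                         gather_dim=0,
                         shard_dim=0,    # reused as the split dim during backward
                         logical_process_axis=1)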

View File

@@ -14,24 +14,21 @@ class Layout:
     Attributes:
         device_mesh: the device mesh to store the tensor distributed.
-        device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
         sharding_spec: the sharding specification to describe how the tensor is sharded.
-        entire_shape: the entire shape of the global tensor.
+        global_shape: the entire shape of the global tensor.
     """

-    def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
-                 entire_shape: torch.Size):
+    def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
         self.device_mesh = device_mesh
-        self.device_type = device_type
         self.sharding_spec = sharding_spec
-        self.entire_shape = entire_shape
+        self.global_shape = global_shape
         self._sanity_check()

     def __hash__(self) -> int:
         return hash(f'{self.sharding_spec}')

     def get_sharded_shape_per_device(self):
-        sharded_shape = list(self.entire_shape)
+        sharded_shape = list(self.global_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
             mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -56,7 +53,7 @@ class Layout:
         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            tensor_dim_size = self.entire_shape[dim]
+            tensor_dim_size = self.global_shape[dim]
             num_devices = 1

             for element in shard_list:
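
A minimal sketch of building a Layout with the trimmed constructor (device_type is gone and entire_shape becomes global_shape), following the docstring examples that appear later in this diff; the import paths and the expected per-device shape are assumptions.

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.layout import Layout
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec    # assumed import path

    # assumes torch.distributed is already initialized with 4 ranks
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)

    global_shape = (4, 4, 4)
    sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})    # [S0, S1, R]

    layout = Layout(device_mesh=device_mesh,
                    sharding_spec=sharding_spec,
                    global_shape=torch.Size(global_shape))

    # each device should hold a (2, 2, 4) shard: dim 0 split over mesh axis 0, dim 1 over axis 1
    print(layout.get_sharded_shape_per_device())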

View File

@@ -3,10 +3,8 @@ from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Tuple

-import numpy as np
 import torch

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
@@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):


 class LayoutConverter(metaclass=SingletonMeta):
+    """
+    LayoutConverter is a singleton class which converts the layout of a distributed tensor.
+    """

     def __init__(self):
         self._options = None
@@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [[0, 1,
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-       entire_shape = (4, 4, 4)
+       global_shape = (4, 4, 4)
        dim_partition_dict = {0: [0], 1: [1]}

        # [S0,S1,R]
        sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
        layout = Layout(device_mesh=device_mesh,
-                       device_type=torch.device('cuda'),
                        sharding_spec=sharding_spec,
-                       entire_shape=entire_shape)
+                       global_shape=global_shape)

        rst_dict = layout_converter.all_gather_transform_layouts(layout)
        for layout, comm_spec in rst_dict.items():
@@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
        source_spec = source_layout.sharding_spec
-       process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+       # the key of the dict is the axis
+       # the value is the process group
+       current_rank = source_layout.device_mesh._global_rank_of_current_process
+       process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
        for target_pair in source_spec.dim_partition_dict.items():
            shard_list = all_gather_simulator(target_pair)
            index = target_pair[0]
@@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                logical_process_axis = target_pair[1][-1]
                comm_spec = CommSpec(
                    comm_pattern,
-                   process_groups_dict=process_groups_dict,
+                   process_group_dict=process_group_dict,
                    gather_dim=gather_dim,
                    # shard_dim will be used during backward
                    shard_dim=gather_dim,
@@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                new_layout = Layout(device_mesh=source_layout.device_mesh,
                                    sharding_spec=new_sharding_spec,
-                                   device_type=source_layout.device_type,
-                                   entire_shape=source_layout.entire_shape)
+                                   global_shape=source_layout.global_shape)

                valid_spec_dict[new_layout] = comm_spec
            except LayoutException:
@@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [[0, 1,
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-       entire_shape = (4, 4, 4)
+       global_shape = (4, 4, 4)
        dim_partition_dict = {0: [0], 1: [1]}

        # [S0,S1,R]
        sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
        layout = Layout(device_mesh=device_mesh,
-                       device_type=torch.device('cuda'),
                        sharding_spec=sharding_spec,
-                       entire_shape=entire_shape)
+                       global_shape=global_shape)

        rst_dict = layout_converter.all_to_all_transform_layout(layout)
        for layout, comm_spec in rst_dict.items():
@@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
        '''
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
-       process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+       # the key of the dict is the axis
+       # the value is the process group
+       current_rank = source_layout.device_mesh._global_rank_of_current_process
+       process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
        source_spec = source_layout.sharding_spec
        tensor_dims = source_spec.dims
        for f_index in range(tensor_dims - 1):
@@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                        shard_dim = f_index
                        logical_process_axis = b_target_pair[1][-1]
                        comm_spec = CommSpec(comm_pattern,
-                                            process_groups_dict,
+                                            process_group_dict=process_group_dict,
                                             gather_dim=gather_dim,
                                             shard_dim=shard_dim,
                                             logical_process_axis=logical_process_axis)
@@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                        new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                        new_layout = Layout(device_mesh=source_layout.device_mesh,
                                            sharding_spec=new_sharding_spec,
-                                           device_type=source_layout.device_type,
-                                           entire_shape=source_layout.entire_shape)
+                                           global_shape=source_layout.global_shape)
                        valid_spec_dict[new_layout] = comm_spec
                    except LayoutException:
                        pass
@@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [[0, 1,
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-       entire_shape = (4, 4, 4)
+       global_shape = (4, 4, 4)

        dim_partition_dict = {0: [0]}

        # [S0,R,R]
        sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
        layout = Layout(device_mesh=device_mesh,
-                       device_type=torch.device('cuda'),
                        sharding_spec=sharding_spec,
-                       entire_shape=entire_shape)
+                       global_shape=global_shape)

        rst_dict = layout_converter.shard_transform_layout(layout)
        for layout, comm_spec in rst_dict.items():
@@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
        source_spec = source_layout.sharding_spec
-       process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+       # the key of the dict is the axis
+       # the value is the process group
+       current_rank = source_layout.device_mesh._global_rank_of_current_process
+       process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

        # legal sharding dims means the mesh_id is still available to use.
        legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
@@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                shard_dim = index
                logical_process_axis = shard_list[-1]
                comm_spec = CommSpec(comm_pattern,
-                                    process_groups_dict,
+                                    process_group_dict=process_group_dict,
                                     gather_dim=shard_dim,
                                     shard_dim=shard_dim,
                                     logical_process_axis=logical_process_axis)
@@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                                 dim_partition_dict=new_dim_partition_dict)
                new_layout = Layout(device_mesh=source_layout.device_mesh,
                                    sharding_spec=new_sharding_spec,
-                                   device_type=source_layout.device_type,
-                                   entire_shape=source_layout.entire_shape)
+                                   global_shape=source_layout.global_shape)
                valid_spec_dict[new_layout] = comm_spec
            except LayoutException:
                pass
@@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [[0, 1,
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-       entire_shape = (4, 4, 4)
+       global_shape = (4, 4, 4)

        dim_partition_source = {1: [0, 1]}
        dim_partition_target = {0: [0, 1]}
@@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [R,S01,R]
        sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
        source_layout = Layout(device_mesh=device_mesh,
-                              device_type=torch.device('cuda'),
                               sharding_spec=sharding_spec_source,
-                              entire_shape=entire_shape)
+                              global_shape=global_shape)

        # [S01,R,R]
        sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
        target_layout = Layout(device_mesh=device_mesh,
-                              device_type=torch.device('cuda'),
                               sharding_spec=sharding_spec_target,
-                              entire_shape=entire_shape)
+                              global_shape=global_shape)

        transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
        transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
        # [[0, 1,
        #  [2, 3]]
        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-       entire_shape = (4, 4, 4)
+       global_shape = (4, 4, 4)

        # [S0,R,R]
        sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
        source_layout = Layout(device_mesh=device_mesh,
-                              device_type=torch.device('cuda'),
                               sharding_spec=sharding_spec_source,
-                              entire_shape=entire_shape)
+                              global_shape=global_shape)

        # [R,S0,R]
        sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
        target_layout = Layout(device_mesh=device_mesh,
-                              device_type=torch.device('cuda'),
                               sharding_spec=sharding_spec_target,
-                              entire_shape=entire_shape)
+                              global_shape=global_shape)

        if rank in (0, 1):
            sharded_tensor_0 = torch.zeros(2, 1)

View File

@@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD

         # legal sharding dims means the mesh_id is still available to use.
-        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
+        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
         for dim, shard_list in source_spec.dim_partition_dict.items():
             for element in shard_list:
                 legal_sharding_dims.remove(element)
@@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         """
         input_shape = compute_shape(comm_spec.sharding_spec)
         input_numel = np.prod(input_shape)
-        output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+        output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
         peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
         alloc_numel += output_numel
         if discard_input:
@@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         # generate a new tensor
         input_shape = compute_shape(comm_spec.sharding_spec)
         input_numel = np.prod(input_shape)
-        output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+        output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
         alloc_numel += output_numel
         peak_numel = max(peak_numel, alloc_numel)
         if discard_input:

View File

@@ -195,7 +195,7 @@ class ShardingSpec:
     def __repr__(self):
         res_list = ["DistSpec:"]
         res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
-        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}")
+        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
         return ' '.join(res_list)

     def _sanity_check(self):
@@ -222,7 +222,7 @@ class ShardingSpec:
                 num_devices = 1

                 for element in shard_list:
-                    num_devices *= self.device_mesh.mesh_shape[element]
+                    num_devices *= self.device_mesh.shape[element]

                 if tensor_dim_size % num_devices != 0:
                     raise ShardingNotDivisibleError(
@@ -288,7 +288,7 @@ class ShardingSpec:
         sharded_shape = list(self.entire_shape)
         for dim, shard_list in self.dim_partition_dict.items():
-            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
             assert sharded_shape[
                 dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
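
The sharded-shape arithmetic above is unchanged apart from the attribute rename; as a standalone illustration with toy values (not the ShardingSpec implementation itself):

    import operator
    from functools import reduce

    # toy values: a (4, 8) global tensor on a 2x2 device mesh, with dim 1 sharded
    # across both mesh axes (S01 on dim 1)
    device_mesh_shape = (2, 2)
    global_shape = [4, 8]
    dim_partition_dict = {1: [0, 1]}

    sharded_shape = list(global_shape)
    for dim, shard_list in dim_partition_dict.items():
        mesh_list = [device_mesh_shape[mesh_dim] for mesh_dim in shard_list]
        shard_partitions = reduce(operator.mul, mesh_list, 1)
        assert sharded_shape[dim] % shard_partitions == 0, \
            f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
        sharded_shape[dim] //= shard_partitions

    print(sharded_shape)    # [4, 2]: dim 1 is split across 2 * 2 = 4 devices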

View File

@@ -1 +0,0 @@
-from colossalai.tensor.d_tensor.api import to_distributed_tensor

View File

@@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory):
 if __name__ == "__main__":
-    run_test(
-        rank=0,
-        data=get_data(LATENTS_SHAPE),
-        max_memory=None,
-        model=UNet2DModel,
-        print_code=False,
-        print_mem=True,
-        print_est_mem=False,
-        print_progress=False,
-    )
+    test_evoformer_block()

View File

@@ -22,7 +22,7 @@ from tests.kit.model_zoo import model_zoo
 @parameterize('use_safetensors', [False, True])
 def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
     from transformers import BertForSequenceClassification
-    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()

     with shared_tempdir() as tempdir:
@@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 @parameterize('shard', [True, False])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_state_dict(placement_policy, shard: bool, model_name: str):
-    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
     plugin = GeminiPlugin(placement_policy=placement_policy)
     booster = Booster(plugin=plugin)

View File

@@ -8,18 +8,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn


 def test_device_mesh():
-    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+    physical_mesh_id = torch.arange(0, 16)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     # [4, 5, 6, 7],
     # [8, 9, 10,11],
     # [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    assert device_mesh.convert_map[5] == [1, 1]
-    assert device_mesh.convert_map[11] == [2, 3]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
-    assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
+    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
+    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
+    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


 def check_1d_device_mesh():

View File

@@ -20,16 +20,12 @@ def check_layer(rank, world_size, port):
     # [[0, 1,
     # [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
-    logical_process_groups = device_mesh.process_groups_dict

-    for mesh_dim, pgs in logical_pg_dict.items():
-        for index, pg in enumerate(pgs):
-            if rank in pg:
-                tensor = torch.ones(4).cuda()
-                group = logical_process_groups[mesh_dim][index][1]
-                dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
-                assert tensor.equal(tensor_to_check)
+    for axis in range(len(mesh_shape)):
+        tensor = torch.ones(4).cuda()
+        pg = device_mesh.get_process_group(axis=axis)
+        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
+        assert tensor.equal(tensor_to_check)

     gpc.destroy()

View File

@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 from numpy import isin
 from torch.fx import GraphModule
@@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
 from colossalai._analyzer.fx import symbolic_trace


-def trace_model_and_compare_output(model, data_gen):
+def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
     # must turn on eval mode to ensure the output is consistent
     model.eval()

+    inputs = data_gen()
+
+    if ignore_data is not None:
+        # drop the ignore_data key
+        inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
+
     try:
-        kwargs = data_gen()
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to('meta') for k, v in inputs.items()}
         gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")

     # run forward
-    inputs = data_gen()
     non_fx_out = model(**inputs)
     fx_out = gm(**inputs)

View File

@@ -15,7 +15,7 @@ SEQ_LENGTH = 16
 def test_albert():
     sub_registry = model_zoo.get_sub_registry('transformers_albert')

-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
         trace_model_and_compare_output(model, data_gen_fn)

View File

@@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo
 def test_bert():
     sub_registry = model_zoo.get_sub_registry('transformers_bert')

-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
-        trace_model_and_compare_output(model, data_gen_fn)
+        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])


 if __name__ == '__main__':

View File

@@ -47,7 +47,7 @@ def test_diffusers():
     sub_model_zoo = model_zoo.get_sub_registry('diffusers')

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         trace_and_compare(model_fn, data, output_transform_fn)
         torch.cuda.synchronize()

View File

@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
 def test_gpt():
     sub_registry = model_zoo.get_sub_registry('transformers_gpt')

-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()

         # TODO: support the following models
@@ -21,7 +21,7 @@ def test_gpt():
         if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
             continue

-        trace_model_and_compare_output(model, data_gen_fn)
+        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])


 if __name__ == '__main__':

View File

@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
 def test_opt():
     sub_registry = model_zoo.get_sub_registry('transformers_opt')

-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
         trace_model_and_compare_output(model, data_gen_fn)

View File

@@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
 def test_t5():
     sub_registry = model_zoo.get_sub_registry('transformers_t5')

-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
+        if name == "transformers_t5_for_conditional_generation":
+            # cannot trace for loss function yet
+            # so we use a data gen which does not produce labels
+            data_gen_fn = sub_registry.get('transformers_t5')[1]
+
         model = model_fn()
-        trace_model_and_compare_output(model, data_gen_fn)
+        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])


 if __name__ == '__main__':

View File

@@ -56,7 +56,7 @@ def test_timm_models():
     sub_model_zoo = model_zoo.get_sub_registry('timm')

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         if attribute is not None and attribute.has_control_flow:
             meta_args = {k: v.to('meta') for k, v in data.items()}

View File

@@ -16,7 +16,7 @@ def test_torchaudio_models():
     sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

-    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
         model = model_fn()
         trace_and_compare(model,
                           data_gen_fn,

View File

@@ -53,7 +53,7 @@ def test_torchrec_deepfm_models():
     deepfm_models = model_zoo.get_sub_registry('deepfm')
     torch.backends.cudnn.deterministic = True

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items():
         data = data_gen_fn()
         if attribute is not None and attribute.has_control_flow:
             meta_args = {k: v.to('meta') for k, v in data.items()}

View File

@@ -53,7 +53,7 @@ def test_torchrec_dlrm_models():
     torch.backends.cudnn.deterministic = True
     dlrm_models = model_zoo.get_sub_registry('dlrm')

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items():
         data = data_gen_fn()

         # dlrm_interactionarch is not supported

View File

@@ -10,7 +10,7 @@ def test_torchvision_models():
     torch.backends.cudnn.deterministic = True
     tv_sub_registry = model_zoo.get_sub_registry('torchvision')

-    for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items():
         data = data_gen_fn()

         if model_attribute is not None and model_attribute.has_stochastic_depth_prob:

View File

@@ -6,6 +6,7 @@ import numpy as np
 import torch
 from packaging import version

+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
 from colossalai.tensor.d_tensor import to_global
 from colossalai.tensor.d_tensor.layout import Layout
@@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
         print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
+                            sharding_spec_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()

View File

@@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim


-def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
+def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    return layout
+    return target_sharding_spec


 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')


-def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
-    layout_dict = {}
+def generate_sharding_spec_dict(model: nn.Module) -> dict:
+    sharding_spec_dict = {}

     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):
@@ -53,17 +49,17 @@ def generate_sharding_spec_dict(model: nn.Module) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                layout = make_layout(device_mesh, param)
-                layout_dict[_get_current_name(prefix, name)] = layout
+                sharding_spec = make_sharding_spec(param)
+                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                layout = make_layout(device_mesh, buf)
-                layout_dict[_get_current_name(prefix, name)] = layout
+                sharding_spec = make_sharding_spec(buf)
+                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

     generate_recursively(model)
-    return layout_dict
+    return sharding_spec_dict


 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
-        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
             continue
         print_rank_0(name)
         model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
@@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
         ctx = LazyInitContext()
         with ctx:
             deferred_model = model_fn()
-        layout_dict = generate_layout_dict(deferred_model, device_mesh)
-        ctx.distribute(deferred_model, layout_dict, verbose=True)
-        assert_dist_model_equal(model, deferred_model, layout_dict)
+        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
+        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
+        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
def run_dist(rank, world_size, port) -> None: def run_dist(rank, world_size, port) -> None:
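The block below is a rough usage sketch of the flow exercised above (build a device mesh, map parameter names to sharding specs, then hand both to LazyInitContext.distribute); the toy nn.Linear model, the 2x2 mesh, and the hand-written spec dict are illustrative assumptions rather than the test's own setup, which builds its dict via generate_sharding_spec_dict.

# Rough sketch of the lazy-init + distribute flow shown above, assuming a
# 4-process NCCL launch is already initialized. Model and specs are placeholders.
import torch
import torch.nn as nn

from colossalai.device.device_mesh import DeviceMesh
from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)

ctx = LazyInitContext()
with ctx:
    deferred_model = nn.Linear(16, 16)    # parameters stay lazy inside the context

# shard the weight's dim 0 along mesh axis 0, keep the bias replicated
sharding_spec_dict = {
    'weight': ShardingSpec(dim_size=2, dim_partition_dict={0: [0]}),
    'bias': ShardingSpec(dim_size=1, dim_partition_dict={}),
}
ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)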

View File

@@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset):
sub_model_zoo = model_zoo.get_sub_registry(subset)
for name, entry in sub_model_zoo.items():
# TODO(ver217): lazy init does not support weight norm, skip these models
-if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
+if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
continue
check_lazy_init(entry, verbose=True)

View File

@@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
assert tensor_to_comm.equal(tensor_to_check)
-def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
-# tensor to comm
-tensor_to_comm = torch.ones(2, 2).cuda() * rank
-# reduce through logical process axis 0 at flatten device mesh
-# tensor to check
-# tensor([[6., 6.],
-# [6., 6.]])
-tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
-# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
-comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
-tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
-assert tensor_to_comm.equal(tensor_to_check)
def check_comm(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -150,24 +133,22 @@ def check_comm(rank, world_size, port):
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-process_groups_dict = device_mesh.process_groups_dict
+process_group_dict = device_mesh._process_group_dict[rank]
# test all gather
-check_all_gather(process_groups_dict, rank)
+check_all_gather(process_group_dict, rank)
# test shard
-check_shard(process_groups_dict, rank)
+check_shard(process_group_dict, rank)
# test all to all
-check_all_to_all(process_groups_dict, rank)
+check_all_to_all(process_group_dict, rank)
# test all reduce
-check_all_reduce_fwd(process_groups_dict, rank)
-check_all_reduce_bwd(process_groups_dict, rank)
+check_all_reduce_fwd(process_group_dict, rank)
+check_all_reduce_bwd(process_group_dict, rank)
-flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
-# test all reduce in 1D flatten device mesh
-check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
gpc.destroy()
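For context, a small sketch of the per-rank process-group lookup the updated test relies on; it assumes a 4-process NCCL launch is already initialized, and _process_group_dict is the private attribute read above.

# Sketch: build the 2x2 device mesh and fetch this rank's process groups the way
# the updated test does. Assumes torch.distributed is already initialized across
# 4 ranks; _process_group_dict is a private attribute, keyed by rank as above.
import torch
import torch.distributed as dist

from colossalai.device.device_mesh import DeviceMesh

rank = dist.get_rank()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

process_group_dict = device_mesh._process_group_dict[rank]
# process_group_dict is then passed to the check_* helpers defined in this file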

View File

@@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port):
else:
raise ValueError(f'rank {rank} is not in the device mesh')
-dtensor_from_local = distribute_tensor(original_tensor, new_layout)
+dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
if rank == 0:
assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
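A hedged sketch of the reworked distribute_tensor call, which now takes the device mesh and a ShardingSpec rather than a Layout; the mesh, tensor shape, and spec below are illustrative, and the import path is assumed from the surrounding tests.

# Sketch: distribute_tensor(tensor, device_mesh, sharding_spec) after the dtensor
# change. Values are illustrative; assumes a 4-rank distributed launch with CUDA.
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)
original_tensor = torch.rand(4, 8).cuda()

# S01 on dim 0: shard dim 0 across both mesh axes, keep dim 1 replicated
new_sharding_spec = ShardingSpec(dim_size=2, dim_partition_dict={0: [0, 1]})
dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)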

View File

@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
-entire_shape = torch.Size((64, 32, 16))
+global_shape = torch.Size((64, 32, 16))
layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
@@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (2, 2)
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec,
-entire_shape=entire_shape)
+layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
rst_dict = layout_converter.all_gather_transform_layouts(layout)
@@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,S1,R
# device_mesh_shape: (4, 4)
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-layout_all2all = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_all2all,
-entire_shape=entire_shape)
+layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
# shard_sequence: S0,R,R
# device_mesh_shape: (4, 4)
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-shard_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_shard,
-entire_shape=entire_shape)
+shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
@@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-source_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_source,
-entire_shape=entire_shape)
+source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-target_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_target,
-entire_shape=entire_shape)
+target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port):
assert comm_action_sequence[2].shard_dim == 0
assert comm_action_sequence[2].logical_process_axis == 1
-# checkout cached_spec_pairs_transform_path
+# checkout chached_spec_pairs_transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
@@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
# shard_sequence: R,S01,R
# device_mesh_shape: (4, 4)
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-source_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_source,
-entire_shape=entire_shape)
+source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
# DistSpec:
# shard_sequence: S01,R,R
# device_mesh_shape: (4, 4)
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-target_layout = Layout(device_mesh=device_mesh,
-device_type=torch.device('cuda'),
-sharding_spec=sharding_spec_target,
-entire_shape=entire_shape)
+target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
-original_tensor = torch.rand(entire_shape).cuda()
+original_tensor = torch.rand(global_shape).cuda()
# tensor_to_apply: [R, S01, R]
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
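For reference, a short sketch of the slimmed-down Layout constructor used throughout this file (device_mesh, sharding_spec, global_shape; device_type and entire_shape are gone); the concrete mesh and spec values are illustrative.

# Sketch: constructing a Layout with the reduced keyword set used above.
# The 2x2 mesh and the [S0, S1, R] spec are illustrative values.
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

global_shape = torch.Size((64, 32, 16))
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# [S0, S1, R]: dim 0 on mesh axis 0, dim 1 on mesh axis 1, dim 2 replicated
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)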

View File

@@ -1,9 +1,10 @@
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
-from colossalai.device.device_mesh import DeviceMesh
-physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],

View File

@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
# the mesh is in the following topo
# [[0, 1],
# [2, 3]]
-physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
row_id = rank // 2
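As a quick sketch of the updated DeviceMesh construction: the physical mesh id is now passed flat, and the 2x2 topology presumably comes from mesh_shape rather than from a pre-reshaped tensor.

# Sketch: DeviceMesh now takes a flat physical_mesh_id; the logical 2x2 topology
# is described by mesh_shape instead of reshaping the id tensor up front.
import torch

from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(0, 4)    # previously torch.arange(0, 4).reshape(2, 2)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)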

View File

@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
-physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+physical_mesh_id = torch.arange(0, 16)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],