mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
This commit is contained in:
@@ -24,12 +24,12 @@ class CommSpec:
|
||||
'''
|
||||
Communication spec is used to record the communication action. It converts the communication spec
|
||||
to real action which will be used in runtime. It contains comm_pattern to determine the
|
||||
communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
|
||||
communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
|
||||
to determine the buffer shape, and logical_process_axis
|
||||
|
||||
Argument:
|
||||
comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
|
||||
process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
|
||||
process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
|
||||
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
|
||||
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
|
||||
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
|
||||
@@ -37,7 +37,7 @@ class CommSpec:
|
||||
|
||||
def __init__(self,
|
||||
comm_pattern: CollectiveCommPattern,
|
||||
process_groups_dict: Dict,
|
||||
process_group_dict: Dict,
|
||||
gather_dim: int = None,
|
||||
shard_dim: int = None,
|
||||
logical_process_axis: int = None):
|
||||
@@ -45,7 +45,7 @@ class CommSpec:
|
||||
self.gather_dim = gather_dim
|
||||
self.shard_dim = shard_dim
|
||||
self.logical_process_axis = logical_process_axis
|
||||
self.process_groups_dict = process_groups_dict
|
||||
self.process_group_dict = process_group_dict
|
||||
|
||||
def __repr__(self):
|
||||
res_list = ["CommSpec:("]
|
||||
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all gather operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
tensor_list = [
|
||||
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
# without this contiguous operation, the all gather may get some unexpected results.
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_gather(tensor_list, tensor, group=process_group)
|
||||
output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _split(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement shard operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, _ in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
start = length * rank_list.index(dist.get_rank())
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
|
||||
start = length * dist.get_rank(process_group)
|
||||
output = torch.narrow(tensor, dim, start, length).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
|
||||
'''
|
||||
Implement all to all operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [
|
||||
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
|
||||
]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
|
||||
input_tensor_list = [
|
||||
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
|
||||
]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
world_size = dist.get_world_size(process_group)
|
||||
new_shape = list(tensor.shape)
|
||||
new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
|
||||
new_shape = torch.Size(new_shape)
|
||||
output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dim = comm_spec.shard_dim
|
||||
length = tensor.shape[comm_spec.shard_dim] // world_size
|
||||
input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
|
||||
group = process_group
|
||||
dist.all_to_all(output_tensor_list, input_tensor_list, group)
|
||||
output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
|
||||
return output
|
||||
|
||||
|
||||
def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
|
||||
'''
|
||||
Implement all reduce operation on device mesh based on information provided by comm_spec.
|
||||
'''
|
||||
process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
|
||||
for rank_list, process_group in process_groups_list:
|
||||
if dist.get_rank() in rank_list:
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
|
||||
if not tensor.is_contiguous():
|
||||
tensor = tensor.contiguous()
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
|
||||
return tensor
|
||||
|
||||
|
||||
class _ReduceGrad(torch.autograd.Function):
|
||||
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
|
||||
def forward(ctx, input_, comm_spec):
|
||||
output = _all_to_all(input_, comm_spec)
|
||||
comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
|
||||
process_groups_dict=comm_spec.process_groups_dict,
|
||||
process_group_dict=comm_spec.process_group_dict,
|
||||
gather_dim=comm_spec.shard_dim,
|
||||
shard_dim=comm_spec.gather_dim,
|
||||
logical_process_axis=comm_spec.logical_process_axis)
|
||||
|
@@ -14,24 +14,21 @@ class Layout:
|
||||
|
||||
Attributes:
|
||||
device_mesh: the device mesh to store the tensor distributed.
|
||||
device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
|
||||
sharding_spec: the sharding specification to describe how the tensor is sharded.
|
||||
entire_shape: the entire shape of the global tensor.
|
||||
global_shape: the entire shape of the global tensor.
|
||||
"""
|
||||
|
||||
def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
|
||||
entire_shape: torch.Size):
|
||||
def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
|
||||
self.device_mesh = device_mesh
|
||||
self.device_type = device_type
|
||||
self.sharding_spec = sharding_spec
|
||||
self.entire_shape = entire_shape
|
||||
self.global_shape = global_shape
|
||||
self._sanity_check()
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(f'{self.sharding_spec}')
|
||||
|
||||
def get_sharded_shape_per_device(self):
|
||||
sharded_shape = list(self.entire_shape)
|
||||
sharded_shape = list(self.global_shape)
|
||||
for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
|
||||
mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
|
||||
shard_partitions = reduce(operator.mul, mesh_list, 1)
|
||||
@@ -56,7 +53,7 @@ class Layout:
|
||||
|
||||
# make sure that the sharding for a dimension is divisible by the number of devices
|
||||
for dim, shard_list in sharding_spec.dim_partition_dict.items():
|
||||
tensor_dim_size = self.entire_shape[dim]
|
||||
tensor_dim_size = self.global_shape[dim]
|
||||
num_devices = 1
|
||||
|
||||
for element in shard_list:
|
||||
|
@@ -3,10 +3,8 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.tensor.d_tensor.comm_spec import *
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
@@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
|
||||
|
||||
|
||||
class LayoutConverter(metaclass=SingletonMeta):
|
||||
"""
|
||||
LayoutConverter is a singleton class which converts the layout of a distributed tensor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._options = None
|
||||
@@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
for target_pair in source_spec.dim_partition_dict.items():
|
||||
shard_list = all_gather_simulator(target_pair)
|
||||
index = target_pair[0]
|
||||
@@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
logical_process_axis = target_pair[1][-1]
|
||||
comm_spec = CommSpec(
|
||||
comm_pattern,
|
||||
process_groups_dict=process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=gather_dim,
|
||||
# shard_dim will be used during backward
|
||||
shard_dim=gather_dim,
|
||||
@@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
@@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
dim_partition_dict = {0: [0], 1: [1]}
|
||||
|
||||
# [S0,S1,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
rst_dict = layout_converter.all_to_all_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
'''
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
source_spec = source_layout.sharding_spec
|
||||
tensor_dims = source_spec.dims
|
||||
for f_index in range(tensor_dims - 1):
|
||||
@@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = f_index
|
||||
logical_process_axis = b_target_pair[1][-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=gather_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_dict = {0: [0]}
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
rst_dict = layout_converter.shard_transform_layout(layout)
|
||||
|
||||
for layout, comm_spec in rst_dict.items():
|
||||
@@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
valid_spec_dict = {}
|
||||
comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
|
||||
source_spec = source_layout.sharding_spec
|
||||
process_groups_dict = source_layout.device_mesh.process_groups_dict
|
||||
|
||||
# the key of the dict is the axis
|
||||
# the value is the process group
|
||||
current_rank = source_layout.device_mesh._global_rank_of_current_process
|
||||
process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
|
||||
|
||||
# legal sharding dims means the mesh_id is still available to use.
|
||||
legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
|
||||
@@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
shard_dim = index
|
||||
logical_process_axis = shard_list[-1]
|
||||
comm_spec = CommSpec(comm_pattern,
|
||||
process_groups_dict,
|
||||
process_group_dict=process_group_dict,
|
||||
gather_dim=shard_dim,
|
||||
shard_dim=shard_dim,
|
||||
logical_process_axis=logical_process_axis)
|
||||
@@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
dim_partition_dict=new_dim_partition_dict)
|
||||
new_layout = Layout(device_mesh=source_layout.device_mesh,
|
||||
sharding_spec=new_sharding_spec,
|
||||
device_type=source_layout.device_type,
|
||||
entire_shape=source_layout.entire_shape)
|
||||
global_shape=source_layout.global_shape)
|
||||
valid_spec_dict[new_layout] = comm_spec
|
||||
except LayoutException:
|
||||
pass
|
||||
@@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
dim_partition_source = {1: [0, 1]}
|
||||
dim_partition_target = {0: [0, 1]}
|
||||
@@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [R,S01,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
# [S01,R,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
|
||||
@@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
entire_shape = (4, 4, 4)
|
||||
global_shape = (4, 4, 4)
|
||||
|
||||
# [S0,R,R]
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
# [R,S0,R]
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
global_shape=global_shape)
|
||||
|
||||
if rank in (0, 1):
|
||||
sharded_tensor_0 = torch.zeros(2, 1)
|
||||
|
Reference in New Issue
Block a user