Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
Revert "[sync] sync feature/shardformer with develop"
@@ -125,6 +125,23 @@ def check_all_reduce_bwd(process_groups_dict, rank):
     assert tensor_to_comm.equal(tensor_to_check)


+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -136,22 +153,24 @@ def check_comm(rank, world_size, port):
     # [[0, 1],
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

-    process_group_dict = device_mesh._process_group_dict[rank]
+    process_groups_dict = device_mesh.process_groups_dict

     # test all gather
-    check_all_gather(process_group_dict, rank)
+    check_all_gather(process_groups_dict, rank)

     # test shard
-    check_shard(process_group_dict, rank)
+    check_shard(process_groups_dict, rank)

     # test all to all
-    check_all_to_all(process_group_dict, rank)
+    check_all_to_all(process_groups_dict, rank)

     # test all reduce
-    check_all_reduce_fwd(process_group_dict, rank)
-    check_all_reduce_bwd(process_group_dict, rank)
+    check_all_reduce_fwd(process_groups_dict, rank)
+    check_all_reduce_bwd(process_groups_dict, rank)

+    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
     gpc.destroy()
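For reference (not part of the commit), the expected tensor([[6., 6.], [6., 6.]]) in the restored flatten-mesh check is just the sum of the four rank ids, 0 + 1 + 2 + 3 = 6. A minimal plain torch.distributed sketch of the same arithmetic; the gloo backend, port, and helper name are illustrative so it runs on CPU without ColossalAI:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _all_reduce_worker(rank: int, world_size: int):
    # Mirror `tensor_to_comm = torch.ones(2, 2).cuda() * rank`, but on CPU with gloo.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    tensor_to_comm = torch.ones(2, 2) * rank
    dist.all_reduce(tensor_to_comm, op=dist.ReduceOp.SUM)
    # Ranks 0..3 form a single flattened group, so the reduced value is 0 + 1 + 2 + 3 = 6.
    assert tensor_to_comm.equal(torch.full((2, 2), 6.0))
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_all_reduce_worker, args=(4,), nprocs=4)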
@@ -31,9 +31,13 @@ def check_dtensor(rank, world_size, port):

     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    d_tensor = DTensor(original_tensor, layout)

-    assert d_tensor.global_shape == original_tensor.shape
+    assert d_tensor.entire_shape == original_tensor.shape
     assert d_tensor.data_type == original_tensor.dtype

     if rank in (0, 1):
@@ -53,7 +57,12 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    d_tensor.layout_convert(device_mesh, new_sharding_spec)
+    new_layout = Layout(device_mesh=device_mesh,
+                        device_type=torch.device('cuda'),
+                        sharding_spec=new_sharding_spec,
+                        entire_shape=original_tensor.shape)
+
+    d_tensor.layout_convert(new_layout)

     if rank == 0:
         assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
@@ -66,7 +75,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
+    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

     if rank == 0:
         assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
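A plain-PyTorch illustration (not from the commit) of the narrow-based expectation asserted above: when dim 0 is sharded over all four ranks of the 2x2 mesh, rank 0 should hold exactly original_tensor.narrow(0, 0, 1). The 4-row global shape below is an assumption, since original_tensor is constructed outside these hunks:

import torch

world_size = 4
# Hypothetical global tensor with one row per rank.
original_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
rows_per_rank = original_tensor.size(0) // world_size
for rank in range(world_size):
    # Local shard owned by `rank` when dim 0 is split across the four devices.
    local_tensor = original_tensor.narrow(0, rank * rows_per_rank, rows_per_rank)
    assert local_tensor.equal(original_tensor[rank * rows_per_rank:(rank + 1) * rows_per_rank])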
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

-global_shape = torch.Size((64, 32, 16))
+entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4)
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
 mesh_shape = (2, 2)
@@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=sharding_spec,
+                    entire_shape=entire_shape)

     rst_dict = layout_converter.all_gather_transform_layouts(layout)
@@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (4, 4)
     sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-    layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
+    layout_all2all = Layout(device_mesh=device_mesh,
+                            device_type=torch.device('cuda'),
+                            sharding_spec=sharding_spec_all2all,
+                            entire_shape=entire_shape)

     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-    shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
+    shard_layout = Layout(device_mesh=device_mesh,
+                          device_type=torch.device('cuda'),
+                          sharding_spec=sharding_spec_shard,
+                          entire_shape=entire_shape)

     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
@@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

-    original_tensor = torch.rand(global_shape).cuda()
+    original_tensor = torch.rand(entire_shape).cuda()

     # tensor_to_apply: [R, S01, R]
     tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
@@ -1,10 +1,9 @@
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
 import torch

-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

-physical_mesh_id = torch.arange(0, 16)
+physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
 mesh_shape = (4, 4)
 # [[0, 1, 2, 3],
 # [4, 5, 6, 7],
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
     # the mesh is in the following topo
     # [[0, 1],
     #  [2, 3]]
-    physical_mesh_id = torch.arange(0, 4)
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     row_id = rank // 2
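On the [[0, 1], [2, 3]] mesh above, row_id = rank // 2 recovers a rank's row index; the matching column index would be rank % 2 (the column half is an assumption here, as it falls outside the hunk). A tiny sketch:

# Rank-to-coordinate mapping on the logical mesh [[0, 1], [2, 3]].
for rank in range(4):
    row_id, col_id = rank // 2, rank % 2
    print(rank, (row_id, col_id))    # 0 -> (0, 0), 1 -> (0, 1), 2 -> (1, 0), 3 -> (1, 1)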
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


 def test_sharding_spec():
-    physical_mesh_id = torch.arange(0, 16)
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     # [4, 5, 6, 7],