mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[test] fixed tests failed due to dtensor change (#4082)
* [test] fixed tests failed due to dtensor change * polish code
This commit is contained in:
@@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_test(
|
||||
rank=0,
|
||||
data=get_data(LATENTS_SHAPE),
|
||||
max_memory=None,
|
||||
model=UNet2DModel,
|
||||
print_code=False,
|
||||
print_mem=True,
|
||||
print_est_mem=False,
|
||||
print_progress=False,
|
||||
)
|
||||
test_evoformer_block()
|
||||
|
@@ -22,7 +22,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
@parameterize('use_safetensors', [False, True])
|
||||
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
|
||||
from transformers import BertForSequenceClassification
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
@@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
|
||||
@parameterize('shard', [True, False])
|
||||
@parameterize('model_name', ['transformers_gpt'])
|
||||
def exam_state_dict(placement_policy, shard: bool, model_name: str):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
(model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
@@ -8,18 +8,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def test_device_mesh():
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
physical_mesh_id = torch.arange(0, 16)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
# [8, 9, 10,11],
|
||||
# [12,13,14,15]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
assert device_mesh.convert_map[5] == [1, 1]
|
||||
assert device_mesh.convert_map[11] == [2, 3]
|
||||
assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
|
||||
assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
|
||||
assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
|
||||
assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
|
||||
assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
|
||||
assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
|
||||
|
||||
|
||||
def check_1d_device_mesh():
|
||||
|
@@ -20,16 +20,12 @@ def check_layer(rank, world_size, port):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
|
||||
logical_process_groups = device_mesh.process_groups_dict
|
||||
|
||||
for mesh_dim, pgs in logical_pg_dict.items():
|
||||
for index, pg in enumerate(pgs):
|
||||
if rank in pg:
|
||||
tensor = torch.ones(4).cuda()
|
||||
group = logical_process_groups[mesh_dim][index][1]
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
|
||||
assert tensor.equal(tensor_to_check)
|
||||
for axis in range(len(mesh_shape)):
|
||||
tensor = torch.ones(4).cuda()
|
||||
pg = device_mesh.get_process_group(axis=axis)
|
||||
dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
|
||||
assert tensor.equal(tensor_to_check)
|
||||
|
||||
gpc.destroy()
|
||||
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from numpy import isin
|
||||
from torch.fx import GraphModule
|
||||
@@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
||||
|
||||
def trace_model_and_compare_output(model, data_gen):
|
||||
def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
model.eval()
|
||||
|
||||
inputs = data_gen()
|
||||
|
||||
if ignore_data is not None:
|
||||
# drop the ignore_data key
|
||||
inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
|
||||
|
||||
try:
|
||||
kwargs = data_gen()
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
meta_args = {k: v.to('meta') for k, v in inputs.items()}
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||
|
||||
# run forward
|
||||
inputs = data_gen()
|
||||
non_fx_out = model(**inputs)
|
||||
fx_out = gm(**inputs)
|
||||
|
||||
|
@@ -15,7 +15,7 @@ SEQ_LENGTH = 16
|
||||
def test_albert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_albert')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
@@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_bert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_bert')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -47,7 +47,7 @@ def test_diffusers():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
trace_and_compare(model_fn, data, output_transform_fn)
|
||||
torch.cuda.synchronize()
|
||||
|
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_gpt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
|
||||
# TODO: support the following models
|
||||
@@ -21,7 +21,7 @@ def test_gpt():
|
||||
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
|
||||
continue
|
||||
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_opt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
@@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_t5():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_t5')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
if name == "transformers_t5_for_conditional_generation":
|
||||
# cannot trace for loss function yet
|
||||
# so we use a data gen which does not produce labels
|
||||
data_gen_fn = sub_registry.get('transformers_t5')[1]
|
||||
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -56,7 +56,7 @@ def test_timm_models():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('timm')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
if attribute is not None and attribute.has_control_flow:
|
||||
meta_args = {k: v.to('meta') for k, v in data.items()}
|
||||
|
@@ -16,7 +16,7 @@ def test_torchaudio_models():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
model = model_fn()
|
||||
trace_and_compare(model,
|
||||
data_gen_fn,
|
||||
|
@@ -53,7 +53,7 @@ def test_torchrec_deepfm_models():
|
||||
deepfm_models = model_zoo.get_sub_registry('deepfm')
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in deepfm_models.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in deepfm_models.items():
|
||||
data = data_gen_fn()
|
||||
if attribute is not None and attribute.has_control_flow:
|
||||
meta_args = {k: v.to('meta') for k, v in data.items()}
|
||||
|
@@ -53,7 +53,7 @@ def test_torchrec_dlrm_models():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
dlrm_models = model_zoo.get_sub_registry('dlrm')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in dlrm_models.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in dlrm_models.items():
|
||||
data = data_gen_fn()
|
||||
|
||||
# dlrm_interactionarch is not supported
|
||||
|
@@ -10,7 +10,7 @@ def test_torchvision_models():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
tv_sub_registry = model_zoo.get_sub_registry('torchvision')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, model_attribute) in tv_sub_registry.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, model_attribute) in tv_sub_registry.items():
|
||||
data = data_gen_fn()
|
||||
|
||||
if model_attribute is not None and model_attribute.has_stochastic_depth_prob:
|
||||
|
@@ -6,6 +6,7 @@ import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
|
||||
from colossalai.tensor.d_tensor import to_global
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
@@ -82,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
|
||||
print(f'{model.__class__.__name__} pass')
|
||||
|
||||
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
|
||||
def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
|
||||
sharding_spec_dict: dict) -> None:
|
||||
state = model.state_dict()
|
||||
distributed_state = distributed_model.state_dict()
|
||||
|
||||
|
@@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
|
||||
return dim
|
||||
|
||||
|
||||
def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
|
||||
def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
|
||||
shard_dim = find_shard_dim(original_tensor.shape)
|
||||
dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
|
||||
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=target_sharding_spec,
|
||||
entire_shape=original_tensor.shape)
|
||||
return layout
|
||||
return target_sharding_spec
|
||||
|
||||
|
||||
def _get_current_name(prefix: str, name: str) -> str:
|
||||
return f'{prefix}.{name}'.lstrip('.')
|
||||
|
||||
|
||||
def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
|
||||
layout_dict = {}
|
||||
def generate_sharding_spec_dict(model: nn.Module) -> dict:
|
||||
sharding_spec_dict = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_recursively(module: nn.Module, prefix: str = ''):
|
||||
@@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
|
||||
# initialize tensors directly attached to the current module
|
||||
for name, param in module.named_parameters(recurse=False):
|
||||
if isinstance(param, LazyTensor):
|
||||
layout = make_layout(device_mesh, param)
|
||||
layout_dict[_get_current_name(prefix, name)] = layout
|
||||
sharding_spec = make_sharding_spec(param)
|
||||
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
|
||||
|
||||
for name, buf in module.named_buffers(recurse=False):
|
||||
if isinstance(buf, LazyTensor):
|
||||
layout = make_layout(device_mesh, buf)
|
||||
layout_dict[_get_current_name(prefix, name)] = layout
|
||||
sharding_spec = make_sharding_spec(buf)
|
||||
sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
|
||||
|
||||
generate_recursively(model)
|
||||
|
||||
return layout_dict
|
||||
return sharding_spec_dict
|
||||
|
||||
|
||||
@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
|
||||
@@ -75,7 +71,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
|
||||
|
||||
for name, entry in sub_model_zoo.items():
|
||||
# TODO(ver217): lazy init does not support weight norm, skip these models
|
||||
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
|
||||
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
|
||||
continue
|
||||
print_rank_0(name)
|
||||
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
|
||||
@@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
|
||||
ctx = LazyInitContext()
|
||||
with ctx:
|
||||
deferred_model = model_fn()
|
||||
layout_dict = generate_layout_dict(deferred_model, device_mesh)
|
||||
ctx.distribute(deferred_model, layout_dict, verbose=True)
|
||||
assert_dist_model_equal(model, deferred_model, layout_dict)
|
||||
sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
|
||||
ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
|
||||
assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port) -> None:
|
||||
|
@@ -10,7 +10,7 @@ def test_torchvision_models_lazy_init(subset):
|
||||
sub_model_zoo = model_zoo.get_sub_registry(subset)
|
||||
for name, entry in sub_model_zoo.items():
|
||||
# TODO(ver217): lazy init does not support weight norm, skip these models
|
||||
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
|
||||
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'):
|
||||
continue
|
||||
check_lazy_init(entry, verbose=True)
|
||||
|
||||
|
@@ -122,23 +122,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
|
||||
# tensor to comm
|
||||
tensor_to_comm = torch.ones(2, 2).cuda() * rank
|
||||
|
||||
# reduce through logical process axis 0 at flatten device mesh
|
||||
# tensor to check
|
||||
# tensor([[6., 6.],
|
||||
# [6., 6.]])
|
||||
tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
|
||||
|
||||
# CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
|
||||
comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
|
||||
tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
|
||||
|
||||
assert tensor_to_comm.equal(tensor_to_check)
|
||||
|
||||
|
||||
def check_comm(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
@@ -150,24 +133,22 @@ def check_comm(rank, world_size, port):
|
||||
# [[0, 1,
|
||||
# [2, 3]]
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
process_groups_dict = device_mesh.process_groups_dict
|
||||
|
||||
process_group_dict = device_mesh._process_group_dict[rank]
|
||||
|
||||
# test all gather
|
||||
check_all_gather(process_groups_dict, rank)
|
||||
check_all_gather(process_group_dict, rank)
|
||||
|
||||
# test shard
|
||||
check_shard(process_groups_dict, rank)
|
||||
check_shard(process_group_dict, rank)
|
||||
|
||||
# test all to all
|
||||
check_all_to_all(process_groups_dict, rank)
|
||||
check_all_to_all(process_group_dict, rank)
|
||||
|
||||
# test all reduce
|
||||
check_all_reduce_fwd(process_groups_dict, rank)
|
||||
check_all_reduce_bwd(process_groups_dict, rank)
|
||||
check_all_reduce_fwd(process_group_dict, rank)
|
||||
check_all_reduce_bwd(process_group_dict, rank)
|
||||
|
||||
flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
|
||||
# test all reduce in 1D flatten device mesh
|
||||
check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
|
@@ -64,7 +64,7 @@ def check_dtensor(rank, world_size, port):
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
dtensor_from_local = distribute_tensor(original_tensor, new_layout)
|
||||
dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
|
||||
|
||||
if rank == 0:
|
||||
assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
|
||||
|
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
entire_shape = torch.Size((64, 32, 16))
|
||||
global_shape = torch.Size((64, 32, 16))
|
||||
layout_converter = LayoutConverter()
|
||||
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
|
||||
|
||||
@@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (2, 2)
|
||||
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec,
|
||||
entire_shape=entire_shape)
|
||||
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
|
||||
|
||||
rst_dict = layout_converter.all_gather_transform_layouts(layout)
|
||||
|
||||
@@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
|
||||
# shard_sequence: S0,S1,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
|
||||
layout_all2all = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_all2all,
|
||||
entire_shape=entire_shape)
|
||||
layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
|
||||
|
||||
rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
|
||||
|
||||
@@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
|
||||
# shard_sequence: S0,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
|
||||
shard_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_shard,
|
||||
entire_shape=entire_shape)
|
||||
shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
|
||||
|
||||
rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
|
||||
|
||||
@@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
|
||||
# shard_sequence: R,S01,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
|
||||
|
||||
transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
|
||||
|
||||
@@ -137,7 +122,7 @@ def check_layout_converting(rank, world_size, port):
|
||||
assert comm_action_sequence[2].shard_dim == 0
|
||||
assert comm_action_sequence[2].logical_process_axis == 1
|
||||
|
||||
# checkout cached_spec_pairs_transform_path
|
||||
# checkout chached_spec_pairs_transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
|
||||
assert layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
|
||||
|
||||
@@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
|
||||
# shard_sequence: R,S01,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
|
||||
source_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_source,
|
||||
entire_shape=entire_shape)
|
||||
source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
|
||||
|
||||
# DistSpec:
|
||||
# shard_sequence: S01,R,R
|
||||
# device_mesh_shape: (4, 4)
|
||||
sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
|
||||
target_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=sharding_spec_target,
|
||||
entire_shape=entire_shape)
|
||||
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
|
||||
|
||||
original_tensor = torch.rand(entire_shape).cuda()
|
||||
original_tensor = torch.rand(global_shape).cuda()
|
||||
|
||||
# tensor_to_apply: [R, S01, R]
|
||||
tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
|
||||
|
@@ -1,9 +1,10 @@
|
||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
|
||||
import torch
|
||||
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
|
||||
physical_mesh_id = torch.arange(0, 16)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
|
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
|
||||
# the mesh is in the following topo
|
||||
# [[0, 1],
|
||||
# [2, 3]]
|
||||
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
row_id = rank // 2
|
||||
|
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
||||
|
||||
|
||||
def test_sharding_spec():
|
||||
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
|
||||
physical_mesh_id = torch.arange(0, 16)
|
||||
mesh_shape = (4, 4)
|
||||
# [[0, 1, 2, 3],
|
||||
# [4, 5, 6, 7],
|
||||
|
Reference in New Issue
Block a user