diff --git a/tests/components_to_test/bert.py b/tests/components_to_test/bert.py
index 224ae5147..e8d202b69 100644
--- a/tests/components_to_test/bert.py
+++ b/tests/components_to_test/bert.py
@@ -37,9 +37,11 @@ def get_training_components():
     num_head = 4
     sequence_length = 12
     num_layer = 2
+    vocab_size = 30524
 
     def bert_model_builder(checkpoint):
-        config = BertConfig(gradient_checkpointing=checkpoint,
+        config = BertConfig(vocab_size=vocab_size,
+                            gradient_checkpointing=checkpoint,
                             hidden_size=hidden_dim,
                             intermediate_size=hidden_dim * 4,
                             num_attention_heads=num_head,
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 4a8c8f8d1..4d2b1a4aa 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -85,7 +85,7 @@ def set_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def run_1d_col_tp(model_name):
+def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -96,12 +96,12 @@ def run_1d_col_tp(model_name):
     model = model_builder(checkpoint=True)
 
     if 'bert' == model_name:
-        parallel_action_list_col = [
+        parallel_action_list_row = [
             ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
+                           compute_pattern=ComputePattern.TP1DRow_Linear,
                            parallel_mode=ParallelMode.PARALLEL_1D)
         ]
-        spec_col = TensorSpec(parallel_action_list_col)
+        spec_linear_row = TensorSpec(parallel_action_list_row)
 
         parallel_action_list_embedding_col = [
             ParallelAction(priority=1,
@@ -110,13 +110,28 @@
         ]
         spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
 
+        parallel_action_list_embedding_row = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DRow_Embedding,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
+
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             #print(name)
-            if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
-            if '_embeddings' in name and 'weight' in name:
+            # num_class = type_vocab_size = 2 | (8, 2)
+            if 'classifier' in name and 'weight' in name:
+                p.set_spec(spec_linear_row)
+            # num_class = vocab_size = 30524 | (30524, 8)
+            if 'word_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = seq_len = 512 | (512, 8)
+            if 'position_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = type_vocab_size = 2 | (2, 8)
+            if 'token_type_embeddings' in name and 'weight' in name:
                 p.set_spec(spec_embedding_col)
     elif "simple_net" == model_name:
         parallel_action_list_row = [
@@ -334,15 +349,15 @@ def run_1d_row_tp(model_name: str):
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert', 'simple_net']:
+    for name in ['simple_net']:
         run_1d_row_tp(name)
-        run_1d_col_tp(name)
+    for name in ['bert', 'simple_net']:
+        run_1d_hybrid_tp(name)
 
 
 @pytest.mark.dist
-# FIXME(jzy) world size = 4 will fialed
-# @pytest.mark.parametrize('world_size', [4])
-@parameterize('world_size', [1])
+@pytest.mark.parametrize('world_size', [1, 4])
+#@parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
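For context on the row specs applied to the embeddings above: the `num_class` comments indicate the row pattern shards each weight along its first (class/vocab) dimension. A standard way to shard an embedding like that is a Megatron-style vocab-parallel lookup: each rank masks ids outside its local vocabulary slice and the partial results are summed across ranks. Below is a minimal single-process sketch of that idea in plain PyTorch, simulating two ranks with two weight shards; the helper `vocab_parallel_embedding` is illustrative only, not part of the ColossalAI API.

```python
import torch

def vocab_parallel_embedding(ids, shards):
    """Embedding lookup with the weight sharded along the vocab (row) dim.

    Simulates tensor-parallel ranks in one process: each shard holds a
    contiguous slice of the vocabulary rows; the running sum stands in
    for the all-reduce a real 1D TP run performs across ranks.
    """
    out = torch.zeros(*ids.shape, shards[0].shape[1])
    offset = 0
    for shard in shards:
        size = shard.shape[0]
        local = ids - offset                  # ids relative to this shard
        mask = (local < 0) | (local >= size)  # ids owned by other shards
        partial = shard[local.clamp(0, size - 1)]
        partial[mask] = 0.0                   # zero out foreign lookups
        out += partial                        # stands in for all_reduce
        offset += size
    return out

full = torch.randn(8, 4)          # vocab_size=8, hidden=4
shards = [full[:4], full[4:]]     # two "ranks", 4 vocab rows each
ids = torch.tensor([[1, 5, 7]])
assert torch.allclose(vocab_parallel_embedding(ids, shards), full[ids])
```

In a real distributed run the `out += partial` accumulation would be a `dist.all_reduce` over the tensor-parallel group, which is why only weights tagged with a row spec need their lookups masked.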