[Tensor] fix test_model (#916)

* polish test_model

* polish
Ziyue Jiang 2022-05-06 18:06:22 +08:00 committed by GitHub
parent ed6426c300
commit dfaff4e243
2 changed files with 30 additions and 13 deletions


@@ -37,9 +37,11 @@ def get_training_components():
     num_head = 4
     sequence_length = 12
     num_layer = 2
+    vocab_size = 30524
 
     def bert_model_builder(checkpoint):
-        config = BertConfig(gradient_checkpointing=checkpoint,
+        config = BertConfig(vocab_size=vocab_size,
+                            gradient_checkpointing=checkpoint,
                             hidden_size=hidden_dim,
                             intermediate_size=hidden_dim * 4,
                             num_attention_heads=num_head,
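
A note on the value added above: BERT's stock vocabulary has 30522 entries, which does not split evenly across a 1D tensor-parallel group of size 4; 30524 is the next multiple of 4, so a word embedding sharded along the vocabulary dimension gets the same number of rows on every rank. A minimal sketch of that padding rule (pad_vocab_size is a hypothetical helper, not from this repo):

    def pad_vocab_size(vocab_size: int, tp_world_size: int) -> int:
        # Round up to the next multiple of tp_world_size so an embedding
        # sharded along dim 0 has equally sized shards on every rank.
        remainder = vocab_size % tp_world_size
        if remainder == 0:
            return vocab_size
        return vocab_size + tp_world_size - remainder

    assert pad_vocab_size(30522, 4) == 30524  # the value used in the diff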


@@ -85,7 +85,7 @@ def set_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def run_1d_col_tp(model_name):
+def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
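
For context on the rename: "hybrid" means the test now mixes row- and column-style 1D patterns across different parameters of one model instead of applying a single column pattern everywhere. In Megatron-style 1D tensor parallelism, a column-parallel Linear shards its weight along the output dimension and a row-parallel Linear along the input dimension; whether ColossalAI's ComputePattern enums map to exactly this convention is my assumption. A shapes-only sketch, with no communication:

    import torch

    def shard_linear_weight(weight, rank, world_size, mode):
        # weight has shape (out_features, in_features), as in torch.nn.Linear
        if mode == 'col':    # column parallel: split the output dimension
            return torch.chunk(weight, world_size, dim=0)[rank]
        if mode == 'row':    # row parallel: split the input dimension
            return torch.chunk(weight, world_size, dim=1)[rank]
        raise ValueError(f'unknown mode {mode}')

    w = torch.randn(8, 8)  # hidden size 8, matching the shape comments in the next hunk
    print(shard_linear_weight(w, rank=0, world_size=4, mode='col').shape)  # torch.Size([2, 8])
    print(shard_linear_weight(w, rank=0, world_size=4, mode='row').shape)  # torch.Size([8, 2])
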
@@ -96,12 +96,12 @@ def run_1d_col_tp(model_name):
     model = model_builder(checkpoint=True)
 
     if 'bert' == model_name:
-        parallel_action_list_col = [
+        parallel_action_list_row = [
             ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
+                           compute_pattern=ComputePattern.TP1DRow_Linear,
                            parallel_mode=ParallelMode.PARALLEL_1D)
         ]
-        spec_col = TensorSpec(parallel_action_list_col)
+        spec_linear_row = TensorSpec(parallel_action_list_row)
 
         parallel_action_list_embedding_col = [
             ParallelAction(priority=1,
@@ -110,13 +110,28 @@ def run_1d_col_tp(model_name):
         ]
         spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
 
+        parallel_action_list_embedding_row = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DRow_Embedding,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
-            if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
-            if '_embeddings' in name and 'weight' in name:
+            #print(name)
+            # num_class = type_vocab_size = 2 | (8, 2)
+            if 'classifier' in name and 'weight' in name:
+                p.set_spec(spec_linear_row)
+            # num_class = vocab_size = 30524 | (30524, 8)
+            if 'word_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = seq_len = 512 | (512, 8)
+            if 'position_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = type_vocab_size = 2 | (2, 8)
+            if 'token_type_embeddings' in name and 'weight' in name:
                 p.set_spec(spec_embedding_col)
 
     elif "simple_net" == model_name:
         parallel_action_list_row = [
@@ -334,15 +349,15 @@ def run_1d_row_tp(model_name: str):
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert', 'simple_net']:
+    for name in ['simple_net']:
         run_1d_row_tp(name)
-        run_1d_col_tp(name)
+    for name in ['bert', 'simple_net']:
+        run_1d_hybrid_tp(name)
 
 
 @pytest.mark.dist
-# FIXME(jzy) world size = 4 will fail
-# @pytest.mark.parametrize('world_size', [4])
-@parameterize('world_size', [1])
+@pytest.mark.parametrize('world_size', [1, 4])
+#@parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
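
The last hunk cuts off mid-function; in tests written in this style, run_func is then typically handed to torch.multiprocessing, which starts world_size worker processes and passes each one its rank as the first positional argument. A hedged sketch of the usual tail of such a test (the spawn line is an assumption, not quoted from this file):

    import torch.multiprocessing as mp

    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 4])
    @rerun_if_address_is_in_use()
    def test_model(world_size):
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        # spawn() launches world_size processes; each receives its rank
        # as the first positional argument of run_func.
        mp.spawn(run_func, nprocs=world_size)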