Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-05 10:10:32 +00:00)

commit dfaff4e243 (parent ed6426c300)
@@ -37,9 +37,11 @@ def get_training_components():
     num_head = 4
     sequence_length = 12
     num_layer = 2
+    vocab_size = 30524
 
     def bert_model_builder(checkpoint):
-        config = BertConfig(gradient_checkpointing=checkpoint,
+        config = BertConfig(vocab_size=vocab_size,
+                            gradient_checkpointing=checkpoint,
                             hidden_size=hidden_dim,
                             intermediate_size=hidden_dim * 4,
                             num_attention_heads=num_head,
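For context: the builder now pins BertConfig's vocabulary size to the module-level constant instead of the transformers default (30522), which is what the (30524, 8) word-embedding shape comments later in this diff rely on. A minimal sketch of the effect, using only constants visible in the hunk (hidden_dim is defined just above it; the shape comments below imply a value of 8, which is assumed here):

    from transformers import BertConfig, BertModel

    hidden_dim = 8        # assumed from the (30524, 8) comment below
    vocab_size = 30524

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,
                        num_attention_heads=4,
                        num_hidden_layers=2)
    model = BertModel(config)
    # The word embedding now matches the shape the test expects:
    assert tuple(model.embeddings.word_embeddings.weight.shape) == (30524, 8)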
@@ -85,7 +85,7 @@ def set_seed(seed):
     torch.backends.cudnn.deterministic = True
 
 
-def run_1d_col_tp(model_name):
+def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
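The rename is the heart of this commit: run_1d_col_tp applied a single column-parallel pattern to the BERT parameters, whereas run_1d_hybrid_tp mixes row-parallel linears with row- and column-parallel embeddings, as the next two hunks set up.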
@@ -96,12 +96,12 @@ def run_1d_col_tp(model_name):
     model = model_builder(checkpoint=True)
 
     if 'bert' == model_name:
-        parallel_action_list_col = [
+        parallel_action_list_row = [
             ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
+                           compute_pattern=ComputePattern.TP1DRow_Linear,
                            parallel_mode=ParallelMode.PARALLEL_1D)
         ]
-        spec_col = TensorSpec(parallel_action_list_col)
+        spec_linear_row = TensorSpec(parallel_action_list_row)
 
         parallel_action_list_embedding_col = [
             ParallelAction(priority=1,
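For intuition about the TP1DCol_Linear → TP1DRow_Linear switch: a row-parallel Linear shards the weight along the input dimension and all-reduces partial sums, while a column-parallel one shards the output dimension and gathers slices. A standalone sketch of that arithmetic (plain torch, not ColossalAI internals; all sizes made up):

    import torch

    torch.manual_seed(0)
    in_dim, out_dim, world = 8, 4, 2
    x = torch.randn(3, in_dim)
    w = torch.randn(in_dim, out_dim)

    full = x @ w

    # Row parallel: shard the input dimension; each "rank" computes a
    # partial product and the results are summed (the all-reduce).
    shard = in_dim // world
    partials = [x[:, r * shard:(r + 1) * shard] @ w[r * shard:(r + 1) * shard]
                for r in range(world)]
    assert torch.allclose(full, sum(partials), atol=1e-5)

    # Column parallel: shard the output dimension; each rank computes a
    # slice of the output and the results are concatenated (an all-gather).
    shard = out_dim // world
    slices = [x @ w[:, r * shard:(r + 1) * shard] for r in range(world)]
    assert torch.allclose(full, torch.cat(slices, dim=1), atol=1e-5)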
@@ -110,13 +110,28 @@ def run_1d_col_tp(model_name):
         ]
         spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
 
+        parallel_action_list_embedding_row = [
+            ParallelAction(priority=1,
+                           compute_pattern=ComputePattern.TP1DRow_Embedding,
+                           parallel_mode=ParallelMode.PARALLEL_1D)
+        ]
+        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
+
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             #print(name)
-            if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
-            if '_embeddings' in name and 'weight' in name:
+            # num_class = type_vocab_size = 2 | (8, 2)
+            if 'classifier' in name and 'weight' in name:
+                p.set_spec(spec_linear_row)
+            # num_class = vocab_size = 30524 | (30524, 8)
+            if 'word_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = seq_len = 512 | (512, 8)
+            if 'position_embeddings' in name and 'weight' in name:
+                p.set_spec(spec_embedding_row)
+            # num_class = type_vocab_size = 2 | (2, 8)
+            if 'token_type_embeddings' in name and 'weight' in name:
                 p.set_spec(spec_embedding_col)
     elif "simple_net" == model_name:
         parallel_action_list_row = [
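The new name-based dispatch sends the large embeddings (word: 30524 rows, position: 512 rows) to the row-parallel pattern, while token_type_embeddings, with only type_vocab_size = 2 rows, keeps the column pattern, presumably because two rows cannot be sharded across larger groups. A row-parallel lookup amounts to each rank owning a slice of the rows, masking out-of-range ids, and summing; a standalone sketch (plain torch, made-up sizes, not the library's kernel):

    import torch

    torch.manual_seed(0)
    vocab, hidden, world = 8, 4, 2
    weight = torch.randn(vocab, hidden)
    ids = torch.tensor([0, 3, 5, 7])

    full = weight[ids]                     # ordinary embedding lookup

    shard = vocab // world
    partials = torch.zeros(world, len(ids), hidden)
    for r in range(world):
        local = weight[r * shard:(r + 1) * shard]        # this rank's rows
        in_range = (ids >= r * shard) & (ids < (r + 1) * shard)
        local_ids = (ids - r * shard).clamp(0, shard - 1)
        partials[r] = local[local_ids] * in_range.unsqueeze(1)  # mask misses
    assert torch.allclose(full, partials.sum(0))         # the "all-reduce"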
@@ -334,15 +349,15 @@ def run_1d_row_tp(model_name: str):
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert', 'simple_net']:
+    for name in ['simple_net']:
         run_1d_row_tp(name)
-        run_1d_col_tp(name)
+    for name in ['bert', 'simple_net']:
+        run_1d_hybrid_tp(name)
 
 
 @pytest.mark.dist
-# FIXME(jzy) world size = 4 will fialed
-# @pytest.mark.parametrize('world_size', [4])
-@parameterize('world_size', [1])
+@pytest.mark.parametrize('world_size', [1, 4])
+#@parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_model(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
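The hunk is truncated at the run_func partial. In tests of this era the partial is typically handed to torch.multiprocessing's spawn, one process per rank; a hedged sketch of that launcher shape (the spawn call is not shown in this diff, and the free_port import mirrors the test's own):

    from functools import partial
    import torch.multiprocessing as mp
    from colossalai.utils import free_port   # assumed, as in the test file

    def run_dist(rank, world_size, port):
        ...  # colossalai.launch(...) plus the parallel checks above

    def launch(world_size):
        # mp.spawn supplies the process index as the first positional
        # argument, which fills run_dist's `rank` parameter.
        run_func = partial(run_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)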