[Tensor] init ColoParameter (#914)

This commit is contained in:
Jiarui Fang
2022-05-06 12:57:14 +08:00
committed by GitHub
parent 193d629311
commit ab95ec9aea
6 changed files with 77 additions and 44 deletions

View File

@@ -38,17 +38,23 @@ def run_1d_col_tp():
model = model_builder(checkpoint=True)
parallel_action_list_row = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1DRow_Linear,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_row = TensorSpec(parallel_action_list_row)
parallel_action_list_col = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1DCol_Linear,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_col = TensorSpec(parallel_action_list_col)
parallel_action_list_embedding_col = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1DCol_Embedding,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
@@ -125,6 +131,9 @@ def test_model_parameters():
param_cnt += 1
assert param_cnt == 5
for name, colo_p in model.colo_named_parameters():
assert colo_p.is_model_data()
param_cnt = 0
for name, p in model.named_parameters(recurse=False):
param_cnt += 1
@@ -175,12 +184,16 @@ def run_1d_row_tp():
model = model_builder(checkpoint=True)
parallel_action_list = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1DRow_Linear,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec = TensorSpec(parallel_action_list)
parallel_action_list_embedding_row = [
ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1DRow_Embedding,
parallel_mode=ParallelMode.PARALLEL_1D)
]
spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
@@ -243,6 +256,7 @@ def run_dist(rank, world_size, port):
run_1d_row_tp()
run_1d_col_tp()
@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
@@ -252,6 +266,6 @@ def test_simple_net(world_size):
if __name__ == '__main__':
test_simple_net()
# test_model_parameters()
# test_simple_net()
test_model_parameters()
# test_colo_optimizer()