[tensor] refactor parallel action (#1007)

* refactor parallel action

* polish unit tests
This commit is contained in:
ver217
2022-05-20 20:19:58 +08:00
committed by GitHub
parent 9e3d602dba
commit a3b66f6def
12 changed files with 45 additions and 77 deletions

View File

@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
def init_1d_row(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -26,7 +26,7 @@ def init_1d_row(weight):
def init_1d_col(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)

View File

@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def init_1d_row_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
def init_1d_col_spec(model):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):

View File

@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
def init_1d_row(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
def init_1d_col(weight, bias):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)

View File

@@ -20,19 +20,15 @@ from _utils import set_seed
def init_1d_row_linear(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
def init_1d_col_linear(weight, gather_out=True):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
ParallelAction(priority=1,
compute_pattern=ComputePattern.TP1D,
parallel_mode=ParallelMode.PARALLEL_1D,
gather_out=gather_out)
])
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
def init_1d_row_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
def init_1d_col_embedding(weight):
spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
ParallelAction(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)