mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[tensor] refactor parallel action (#1007)
* refactor parallel action
* polish unit tests
This commit is contained in:
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
|
||||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
|
||||
def init_1d_row(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -26,7 +26,7 @@ def init_1d_row(weight):
|
||||
def init_1d_col(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
def init_1d_row_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
|
||||
def init_1d_col_spec(model):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
|
@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
|
||||
def init_1d_row(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
|
||||
def init_1d_col(weight, bias):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
|
@@ -20,19 +20,15 @@ from _utils import set_seed
|
||||
def init_1d_row_linear(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight, gather_out=True):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
|
||||
ParallelAction(priority=1,
|
||||
compute_pattern=ComputePattern.TP1D,
|
||||
parallel_mode=ParallelMode.PARALLEL_1D,
|
||||
gather_out=gather_out)
|
||||
])
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
|
||||
def init_1d_row_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
|
||||
def init_1d_col_embedding(weight):
|
||||
spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
[ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
|
||||
|
Reference in New Issue
Block a user