[tensor] refactor parallel action (#1007)

* refactor parallel action

* polish unit tests
ver217 committed 2022-05-20 20:19:58 +08:00 (via GitHub)
parent 9e3d602dba
commit a3b66f6def
12 changed files with 45 additions and 77 deletions
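
The gist of the refactor, before the per-file diffs: ParallelAction drops its priority and parallel_mode fields (the 1D ops now hard-code ParallelMode.PARALLEL_1D), and TensorSpec carries one optional ParallelAction instead of a priority-sorted list. A minimal, self-contained sketch of the new API (class bodies condensed from the spec diff further down; the dist_spec=None stub is illustrative, standing in for a real distspec.shard(...)):

from enum import Enum
from typing import Optional

class ComputePattern(Enum):
    TP1D = 0
    TP2D = 1
    TP2P5D = 2
    TP3D = 3

class ParallelAction:
    # new signature: one compute pattern plus gather_out; no priority/parallel_mode
    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
        assert isinstance(compute_pattern, ComputePattern)
        self.compute_pattern = compute_pattern
        self.gather_out = gather_out

class TensorSpec:
    # new signature: a single optional action instead of a sorted list
    def __init__(self, dist_spec, parallel_action: Optional[ParallelAction] = None):
        self.parallel_action = parallel_action
        self.dist_spec = dist_spec

# old call sites:
#   TensorSpec(dist_spec, [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D,
#                                         parallel_mode=ParallelMode.PARALLEL_1D)])
# new call sites:
spec = TensorSpec(dist_spec=None, parallel_action=ParallelAction(ComputePattern.TP1D))
assert spec.parallel_action.compute_pattern is ComputePattern.TP1D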

==== file 1 of 12 ====

@@ -3,12 +3,12 @@ from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
 from colossalai.tensor import distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor


 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
@@ -18,7 +18,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
     assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso

 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = mat2.spec.parallel_action
     mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
-    mat1 = reduce_grad(mat1, parallel_action.parallel_mode)
+    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+                             ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     if parallel_action.gather_out:
         # All-Gather(Output)
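
The 1Drow comments above (mat1:S[1] x mat2:S[0] = Output:P, then All-Reduce) can be checked on a single process with plain torch. The world size and chunk sizes here are illustrative, and sum() stands in for reduce_input:

import torch

P = 2                                    # illustrative "world size"
inp = torch.randn(4, 6)
mat1 = torch.randn(4, 8)
mat2 = torch.randn(8, 6)
beta, alpha = 1.0, 1.0

# each "rank" holds a column shard of mat1 and the matching row shard of mat2
partials = [torch.mm(mat1[:, r * 4:(r + 1) * 4], mat2[r * 4:(r + 1) * 4, :]) for r in range(P)]
reduced = sum(partials)                  # sum() plays the role of the all-reduce
out = beta * inp + alpha * reduced
assert torch.allclose(out, torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha), atol=1e-5)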

==== file 2 of 12 ====

@@ -1,10 +1,10 @@
-import torch
 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor
@@ -17,7 +17,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

     output_parallel = F.embedding(input_tensor,
@@ -29,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   sparse=sparse)
     output_spec = TensorSpec(
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)])
+        ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     return output
@@ -45,10 +44,9 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))

-    tensor_parallel_rank = gpc.get_local_rank(parallel_action.parallel_mode)
+    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -72,7 +70,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
     return output
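
A single-process sanity check of the 1Drow masking logic above: each shard owns a range of the vocab, out-of-range ids are parked at a local index, their rows zeroed, and the reduce (a plain sum here) restores the full lookup. Sizes are illustrative:

import torch
import torch.nn.functional as F

P, vocab, dim = 2, 8, 4                  # illustrative world size and table shape
weight = torch.randn(vocab, dim)         # full lookup table
ids = torch.tensor([[0, 5], [7, 2]])

partials = []
per_part = vocab // P                    # rows owned by each "rank"
for rank in range(P):
    start, end = rank * per_part, (rank + 1) * per_part
    input_mask = (ids < start) | (ids >= end)
    masked = ids - start                 # shift ids into the local shard
    masked[input_mask] = 0               # park out-of-range ids at row 0
    part = F.embedding(masked, weight[start:end])
    part[input_mask, :] = 0.             # zero the rows this shard doesn't own
    partials.append(part)

assert torch.allclose(sum(partials), F.embedding(ids, weight))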

==== file 3 of 12 ====

@@ -1,16 +1,14 @@
-import torch
 import torch.nn.functional as F
-import torch.distributed as dist
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
 from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
+from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor


 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -20,7 +18,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Output:P
     partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
-    output = reduce_input(partial_output, parallel_action.parallel_mode)
+    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
         assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
@@ -34,15 +32,16 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.get_action_by_compute_pattern(ComputePattern.TP1D)
+    parallel_action = weight.spec.parallel_action
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor, parallel_action.parallel_mode)
+    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)

     output_parallel = F.linear(input_parallel, weight, bias)
-    output = ColoTensor.from_torch_tensor(
-        output_parallel,
-        spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                        [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]))
+    output = ColoTensor.from_torch_tensor(output_parallel,
+                                          spec=TensorSpec(
+                                              distspec.shard(weight.spec.get_process_group(), [-1],
                                                             [weight.spec.get_process_group_size()]),
+                                              ParallelAction(ComputePattern.TP1D)))
     if parallel_action.gather_out:
         # All-Gather(Output)
         output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
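
The 1Dcol layout (Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]) can likewise be checked on one process: with weight and bias sharded along the output dimension, concatenating the per-shard outputs (the all-gather when gather_out is True) matches the full linear. Shapes are illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)                    # replicated input (Input:B)
weight = torch.randn(6, 8)               # out_features=6, split into two shards of 3
bias = torch.randn(6)

# each "rank" computes its slice of the output with its weight/bias shard
shards = [F.linear(x, weight[r * 3:(r + 1) * 3], bias[r * 3:(r + 1) * 3]) for r in range(2)]
# torch.cat stands in for the all-gather taken when gather_out is True
assert torch.allclose(torch.cat(shards, dim=-1), F.linear(x, weight, bias))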

==== file 4 of 12 ====

@@ -28,7 +28,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        reduction=reduction,
                        label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:    # Single Model Parallel Applied
+    elif input_tensor.has_spec():    # Single Model Parallel Applied
         if input_tensor.spec.is_1D_col():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)

==== file 5 of 12 ====

@@ -33,3 +33,6 @@ class ColoParameter(ColoTensor):
         tensor = tensor.as_subclass(ColoParameter)
         tensor.__init__(tensor, requires_grad=requires_grad, spec=spec)
         return tensor
+
+    def __repr__(self):
+        return f'ColoParameter: {torch.Tensor.__repr__(self)}'

==== file 6 of 12 ====

@@ -45,7 +45,7 @@ class ColoTensor(torch.Tensor):
         self._spec = spec

     def has_spec(self) -> bool:
-        return self._spec.num_action > 0
+        return self._spec.parallel_action is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

==== file 7 of 12 ====

@@ -1,26 +1,21 @@
 import torch.distributed as dist
 from enum import Enum
-from typing import List
-from colossalai.context.parallel_mode import ParallelMode
+from typing import List, Optional
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern


 class ComputePattern(Enum):
     TP1D = 0
-    ZeRO = 1
-    DP = 2
+    TP2D = 1
+    TP2P5D = 2
+    TP3D = 3


 class ParallelAction(object):

-    def __init__(self,
-                 priority=0,
-                 compute_pattern=ComputePattern.DP,
-                 parallel_mode=ParallelMode.DATA,
-                 gather_out=True) -> None:
-        self.priority = priority
+    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
+        assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
-        self.parallel_mode = parallel_mode
         self.gather_out = gather_out
@@ -48,32 +43,9 @@ class TensorSpec(object):
     # We perform Linear Op according to compute pattern of TP1D_Linear.
     # After Linear Op, we split the tensors according to ZeRO.

-    def __init__(self, dist_spec: _DistSpec, parallel_action_list: List[ParallelAction] = []):
-        self._parallel_action_list = parallel_action_list
+    def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
+        self.parallel_action = parallel_action
         self.dist_spec = dist_spec
-        self.sort()
-
-    @property
-    def parallel_action_list(self):
-        return self._parallel_action_list
-
-    @property
-    def num_action(self):
-        return len(self._parallel_action_list)
-
-    @property
-    def compute_patterns(self):
-        return [parallel_action.compute_pattern for parallel_action in self._parallel_action_list]
-
-    def sort(self):
-        if len(self._parallel_action_list) > 0:
-            self._parallel_action_list.sort(key=lambda parallel_action: parallel_action.priority)
-
-    def get_action_by_compute_pattern(self, compute_pattern: ComputePattern):
-        for parallel_action in self._parallel_action_list:
-            if parallel_action.compute_pattern == compute_pattern:
-                return parallel_action
-        return None

     def get_process_group(self):
         return self.dist_spec.process_group
@@ -99,4 +71,4 @@ class TensorSpec(object):
                and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def has_compute_pattern(self, compute_pattern: ComputePattern):
-        return self.get_action_by_compute_pattern(compute_pattern) is not None
+        return self.parallel_action.compute_pattern == compute_pattern
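
One behavioral nuance in this hunk: the old has_compute_pattern returned False when no matching action existed, while the new body dereferences self.parallel_action directly, so it presumes the spec carries an action. A tiny check, reusing the sketch classes from the note at the top of this page:

spec = TensorSpec(dist_spec=None, parallel_action=ParallelAction(ComputePattern.TP1D))
assert spec.parallel_action.compute_pattern == ComputePattern.TP1D   # has_compute_pattern -> True

bare = TensorSpec(dist_spec=None)        # no action: has_spec() is now False
assert bare.parallel_action is None
# bare.parallel_action.compute_pattern would raise AttributeError here, where the
# old list-based has_compute_pattern simply returned False (an observation from
# the diff, not a documented contract).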

==== file 8 of 12 ====

@@ -41,7 +41,7 @@ class Conv1D(nn.Module):

 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):

 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)

==== file 9 of 12 ====

@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal

 def init_1d_row(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -26,7 +26,7 @@ def init_1d_row(weight):

 def init_1d_col(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

==== file 10 of 12 ====

@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs

 def init_1d_row_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):

 def init_1d_col_spec(model):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):

==== file 11 of 12 ====

@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal

 def init_1d_row(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):

 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)

==== file 12 of 12 ====

@@ -20,19 +20,15 @@ from _utils import set_seed

 def init_1d_row_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)


 def init_1d_col_linear(weight, gather_out=True):
     spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1D,
-                       parallel_mode=ParallelMode.PARALLEL_1D,
-                       gather_out=gather_out)
-    ])
+        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):

 def init_1d_row_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):

 def init_1d_col_embedding(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1D, parallel_mode=ParallelMode.PARALLEL_1D)])
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
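
A one-process illustration of what the gather_out flag (the only knob left on ParallelAction besides the pattern) controls in init_1d_col_linear: with gather_out=False each rank keeps only its S[1] slice of the output instead of all-gathering the full result. Sizes and the fixed rank are illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)
weight = torch.randn(6, 8)               # out_features=6, two shards of 3
rank = 1                                 # pretend to be the second of two ranks

shard_out = F.linear(x, weight[rank * 3:(rank + 1) * 3])   # gather_out=False: stays S[1]
full_out = F.linear(x, weight)                             # what the gather would rebuild
assert torch.allclose(shard_out, full_out[:, 3:6])          # just this rank's slice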