Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-08 12:30:42 +00:00)
@@ -71,6 +71,7 @@ class Linear1D(torch.nn.Module):
 @LAYERS.register_module
 class Classifier1D(ParallelLayer):
     """RowLinear with given weight"""
+
     def __init__(self,
                  in_features: int,
                  num_classes: int,
@@ -127,8 +128,8 @@ class Classifier1D(ParallelLayer):
 
         output_parallel = F.linear(input_, self.weight)
         output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
 
-        output = output + self.bias
+        if self.bias is not None:
+            output = output + self.bias
         return output
 
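The guard added above matters, presumably because Classifier1D can be built without a bias, in which case self.bias is None and the old unconditional `output + self.bias` would raise a TypeError. A minimal single-process sketch of the same pattern, using plain PyTorch and a hypothetical TinyClassifier stand-in rather than the real parallel layer:

# Plain PyTorch, no tensor parallelism; TinyClassifier is an illustrative stand-in.
import torch
import torch.nn.functional as F


class TinyClassifier(torch.nn.Module):

    def __init__(self, in_features: int, num_classes: int, bias: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_classes, in_features))
        self.bias = torch.nn.Parameter(torch.zeros(num_classes)) if bias else None

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        output = F.linear(input_, self.weight)
        if self.bias is not None:   # same guard as in the diff above
            output = output + self.bias
        return output


x = torch.randn(4, 16)
print(TinyClassifier(16, 10, bias=False)(x).shape)   # torch.Size([4, 10]), no TypeError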
@@ -152,6 +153,7 @@ class Linear1D_Col(ParallelLayer):
         which is :math:`Y_i = XA_i`, defaults to False
     :type gather_output: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -233,6 +235,7 @@ class Linear1D_Row(ParallelLayer):
     :param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False
     :type parallel_input: bool, optional
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
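The gather_output and parallel_input flags documented in the two hunks above describe the usual column-then-row pairing of 1D tensor parallelism: the column-parallel layer computes Y_i = X A_i and can leave its output split, and the row-parallel layer accepts that split input so only the final partial sums need a reduction. A single-process sketch of that arithmetic, with plain tensors standing in for ranks (illustrative only, no ColossalAI code):

import torch

torch.manual_seed(0)
world_size = 2
X = torch.randn(4, 8)
A = torch.randn(8, 6)      # first linear, split along columns
B = torch.randn(6, 5)      # second linear, split along rows

A_shards = A.chunk(world_size, dim=1)          # A_i held by each "rank"
B_shards = B.chunk(world_size, dim=0)          # B_i held by each "rank"

# column-parallel: each rank computes Y_i = X @ A_i (gather_output=False keeps it split)
Y_shards = [X @ A_i for A_i in A_shards]
# row-parallel with parallel_input=True: each rank multiplies its own shard,
# then the partial results are summed (the all-reduce in the real layer)
partial = [Y_i @ B_i for Y_i, B_i in zip(Y_shards, B_shards)]
Z = torch.stack(partial).sum(dim=0)

assert torch.allclose(Z, X @ A @ B, atol=1e-5)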
@@ -302,6 +305,7 @@ class Linear1D_Row(ParallelLayer):
 class MixedFusedLayerNorm1D(torch.nn.Module):
     """ Experimental
     """
+
     def __init__(self, normalized_shape, eps=1e-5):
         super(MixedFusedLayerNorm1D, self).__init__()
 
@@ -121,9 +121,10 @@ class classifier_2d(torch.autograd.Function):
             B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
             B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
             B_grad = B_grad.reshape(ctx.B_shape)
 
-            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
-            bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
+            bias_grad = None
+            if ctx.use_bias:
+                bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
+                bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
 
         return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
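The backward change above follows the standard custom-autograd pattern: the function must return one gradient per forward input, so when no bias was used the bias slot gets None instead of a reduced tensor. A single-process sketch of that pattern without the ColossalAI collectives (reduce_scatter / all_reduce are simply omitted here; the class name is illustrative, not the library's own):

import torch


class LinearWithOptionalBias(torch.autograd.Function):

    @staticmethod
    def forward(ctx, A, B, bias):
        ctx.save_for_backward(A, B)
        ctx.use_bias = bias is not None
        out = A @ B
        return out + bias if ctx.use_bias else out

    @staticmethod
    def backward(ctx, output_grad):
        A, B = ctx.saved_tensors
        A_grad = output_grad @ B.t()
        B_grad = A.reshape(-1, A.shape[-1]).t() @ output_grad.reshape(-1, output_grad.shape[-1])
        bias_grad = None
        if ctx.use_bias:
            # sum over every dimension except the last, matching the bias shape
            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
        return A_grad, B_grad, bias_grad


A = torch.randn(4, 8, requires_grad=True)
B = torch.randn(8, 3, requires_grad=True)
LinearWithOptionalBias.apply(A, B, None).sum().backward()     # no bias: bias slot stays None
bias = torch.zeros(3, requires_grad=True)
LinearWithOptionalBias.apply(A, B, bias).sum().backward()
print(bias.grad)                                              # tensor([4., 4., 4.])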
@@ -174,9 +175,9 @@ class Matmul_AB_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)
 
         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opb = [None] * 2
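As I read the rank arithmetic above, the two trailing terms locate the tensor-parallel group owned by this (data, pipeline) pair inside the flat global rank ordering, and summa_dim * row_rank (or col_rank) then picks the broadcast source within the 2D grid. A small worked example with made-up sizes (the concrete numbers are illustrative, not taken from the source):

summa_dim = 2                        # 2 x 2 tensor-parallel grid
tensor_parallel_size = summa_dim ** 2
pipeline_parallel_size = 2
data_parallel_rank = 1
pipeline_parallel_rank = 1
row_rank, col_rank = 1, 0

src_a = summa_dim * row_rank \
    + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
    + pipeline_parallel_rank * tensor_parallel_size
src_b = col_rank \
    + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
    + pipeline_parallel_rank * tensor_parallel_size

print(src_a)   # 2*1 + 1*2*4 + 1*4 = 14
print(src_b)   # 0   + 1*2*4 + 1*4 = 12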
@@ -279,9 +280,9 @@ class Matmul_ABT_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)
 
         src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opb = [None] * 2
         opr = [None] * 2
@@ -393,9 +394,9 @@ class Matmul_ATB_2D(torch.autograd.Function):
         col_group = gpc.get_group(col_parallel_mode)
 
         src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
 
         opa = [None] * 2
         opr = [None] * 2
@@ -38,3 +38,9 @@ class PipelineSharedModuleWrapper:
         for p in module.parameters():
             setattr(p, 'pipeline_shared_module_pg', self.group)
             dist.broadcast(p, src, group=self.group)
+
+    def register_parameter(self, param: nn.Parameter):
+        assert self.ranks_in_group is not None, f'Rank {gpc.get_local_rank(ParallelMode.PIPELINE)} is not in pipeline_ranks {self.pipeline_ranks}'
+        src = self.ranks_in_group[self.pipeline_ranks[0]]
+        setattr(param, 'pipeline_shared_module_pg', self.group)
+        dist.broadcast(param, src, group=self.group)
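The new register_parameter mirrors the existing module-level registration for the case where only a single tensor is shared across pipeline stages (weight tying, for example). A hedged usage sketch: it assumes colossalai.launch has already initialized gpc, and the import path and the list-of-stages constructor argument reflect how the wrapper is commonly used, so treat them as assumptions rather than the documented API.

import torch
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper   # assumed import path

num_stages = gpc.get_world_size(ParallelMode.PIPELINE)
# first and last pipeline stage share the embedding weight (tied LM head)
wrapper = PipelineSharedModuleWrapper([0, num_stages - 1])

if gpc.get_local_rank(ParallelMode.PIPELINE) in (0, num_stages - 1):
    embed = torch.nn.Embedding(50304, 1024)      # illustrative sizes
    wrapper.register_parameter(embed.weight)     # broadcast from the first shared stage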