[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -10,6 +10,13 @@ from .layers import (
)
__all__ = [
'split_batch_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
"split_batch_2p5d",
"reduce_by_batch_2p5d",
"Linear2p5D",
"LayerNorm2p5D",
"Classifier2p5D",
"PatchEmbedding2p5D",
"Embedding2p5D",
"VocabParallelClassifier2p5D",
"VocabParallelEmbedding2p5D",
]

View File

@@ -24,7 +24,6 @@ def get_parallel_rank(parallel_mode: ParallelMode):
class _Classifier2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(
@@ -98,10 +97,21 @@ class _Classifier2p5D(torch.autograd.Function):
return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int,
...], row_rank: int, col_rank: int,
row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor:
def classifier_2p5d(
A: Tensor,
B: Tensor,
bias,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
r"""Classifier.
Args:
@@ -123,9 +133,21 @@ def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: T
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
return _Classifier2p5D.apply(
A,
B,
bias,
tesseract_dim,
out_shape,
row_rank,
col_rank,
row_parallel_mode,
col_parallel_mode,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size,
)
class Matmul_AB_2p5D(torch.autograd.Function):
@@ -153,16 +175,27 @@ class Matmul_AB_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
# A: [b / dq, s, h / q] -> [(b * s) / dq, h / q]
# B: [h / dq, s / q]
# C: [b / dq, s, s / q] -> [(b * s) / dq, s / q]
assert A.shape[-1] == B.shape[-2], \
'Invalid shapes: A={}, B={} for AB.'.format(A.shape, B.shape)
assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
@@ -182,14 +215,18 @@ class Matmul_AB_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_a = (
tesseract_dim * row_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
src_b = (
col_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
opa = [None] * 2
opb = [None] * 2
@@ -205,10 +242,9 @@ class Matmul_AB_2p5D(torch.autograd.Function):
A_list[1 - cur].copy_(A)
opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1, group=row_group, async_op=True)
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
opb[1 - cur] = dist.broadcast(
B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True
)
if opa[cur] is not None:
opa[cur].wait()
@@ -242,14 +278,36 @@ class Matmul_AB_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
A_grad = Matmul_ABT_2p5D.apply(
output_grad,
B,
ctx.tesseract_dim,
ctx.A_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
B_grad = Matmul_ATB_2p5D.apply(
A,
output_grad,
ctx.tesseract_dim,
ctx.B_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -278,13 +336,23 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
assert A.shape[-1] == B.shape[-1], \
'Invalid shapes: A={}, B={} for ABT.'.format(A.shape, B.shape)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
assert A.shape[-1] == B.shape[-1], "Invalid shapes: A={}, B={} for ABT.".format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
@@ -304,14 +372,18 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_b = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_b = (
col_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
src_c = (
tesseract_dim * row_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
opb = [None] * 2
opr = [None] * 2
@@ -323,10 +395,9 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
for i in range(tesseract_dim):
if i != tesseract_dim - 1:
B_list[1 - cur].copy_(B)
opb[1 - cur] = dist.broadcast(B_list[1 - cur],
src=src_b + tesseract_dim,
group=col_group,
async_op=True)
opb[1 - cur] = dist.broadcast(
B_list[1 - cur], src=src_b + tesseract_dim, group=col_group, async_op=True
)
if opr[cur] is not None:
opr[cur].wait()
@@ -372,14 +443,36 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_AB_2p5D.apply(output_grad, B, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_ATB_2p5D.apply(output_grad, A, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
A_grad = Matmul_AB_2p5D.apply(
output_grad,
B,
ctx.tesseract_dim,
ctx.A_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
B_grad = Matmul_ATB_2p5D.apply(
output_grad,
A,
ctx.tesseract_dim,
ctx.B_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
@@ -408,13 +501,23 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, A: Tensor, B: Tensor, tesseract_dim: int, out_shape: Tuple[int, ...], row_rank: int,
col_rank: int, dep_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int):
assert A.shape[-2] == B.shape[-2], \
'Invalid shapes: A={}, B={} for ATB.'.format(A.shape, B.shape)
def forward(
ctx: Any,
A: Tensor,
B: Tensor,
tesseract_dim: int,
out_shape: Tuple[int, ...],
row_rank: int,
col_rank: int,
dep_rank: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
):
assert A.shape[-2] == B.shape[-2], "Invalid shapes: A={}, B={} for ATB.".format(A.shape, B.shape)
if ctx:
ctx.save_for_backward(A, B)
@@ -434,14 +537,18 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
row_group = gpc.get_group(row_parallel_mode)
col_group = gpc.get_group(col_parallel_mode)
src_a = \
tesseract_dim * row_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_c = \
col_rank + tesseract_dim ** 2 * dep_rank + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_a = (
tesseract_dim * row_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
src_c = (
col_rank
+ tesseract_dim**2 * dep_rank
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
opa = [None] * 2
opr = [None] * 2
@@ -499,33 +606,68 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
A, B = ctx.saved_tensors
with torch.no_grad():
A_grad = Matmul_ABT_2p5D.apply(B, output_grad, ctx.tesseract_dim, ctx.A_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
B_grad = Matmul_AB_2p5D.apply(A, output_grad, ctx.tesseract_dim, ctx.B_shape, ctx.row_rank, ctx.col_rank,
ctx.dep_rank, ctx.row_parallel_mode, ctx.col_parallel_mode,
ctx.data_parallel_rank, ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size, ctx.tensor_parallel_size)
A_grad = Matmul_ABT_2p5D.apply(
B,
output_grad,
ctx.tesseract_dim,
ctx.A_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
B_grad = Matmul_AB_2p5D.apply(
A,
output_grad,
ctx.tesseract_dim,
ctx.B_shape,
ctx.row_rank,
ctx.col_rank,
ctx.dep_rank,
ctx.row_parallel_mode,
ctx.col_parallel_mode,
ctx.data_parallel_rank,
ctx.pipeline_parallel_rank,
ctx.pipeline_parallel_size,
ctx.tensor_parallel_size,
)
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class _Add_Bias_2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
row_rank: int, col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
def forward(
ctx: Any,
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
row_rank: int,
col_rank: int,
dep_rank: int,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
if row_rank == 0:
bias_temp = bias.clone()
else:
bias_temp = torch.zeros(output_size_per_partition, dtype=bias.dtype, device=get_current_device())
src_rank = \
col_rank + dep_rank * tesseract_dim ** 2 + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
src_rank = (
col_rank
+ dep_rank * tesseract_dim**2
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
dist.broadcast(bias_temp, src=src_rank, group=get_parallel_group(col_parallel_mode))
ctx.row_rank = row_rank
@@ -559,43 +701,120 @@ class _Add_Bias_2p5D(torch.autograd.Function):
tensor_parallel_size = ctx.tensor_parallel_size
if ctx.bias:
dst_rank = \
col_rank + dep_rank * (tesseract_dim ** 2) + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dst_rank = (
col_rank
+ dep_rank * (tesseract_dim**2)
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
dist.reduce(output_grad, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return \
None, output_grad, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
return (
None,
output_grad,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
else:
grad_tmp = torch.zeros_like(output_grad)
return \
None, grad_tmp, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
return (
None,
grad_tmp,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
else:
reduce_dim = tuple(range(output_grad.ndim - 1))
reduce = torch.sum(output_grad, dim=reduce_dim)
dst_rank = \
col_rank + dep_rank * (tesseract_dim ** 2) + \
data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
pipeline_parallel_rank * tensor_parallel_size
dst_rank = (
col_rank
+ dep_rank * (tesseract_dim**2)
+ data_parallel_rank * pipeline_parallel_size * tensor_parallel_size
+ pipeline_parallel_rank * tensor_parallel_size
)
dist.reduce(reduce, dst=dst_rank, group=get_parallel_group(col_parallel_mode))
if row_rank == 0:
return \
output_grad, reduce, None, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None
return (
output_grad,
reduce,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
else:
reduce_tmp = torch.zeros_like(reduce)
return \
output_grad, reduce_tmp, None, None, None, None, None, None, \
None, None, None, None, None, None, None, None, None
return (
output_grad,
reduce_tmp,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
def add_bias_2p5d(
input: Tensor,
bias: Tensor,
output_size_per_partition: int,
tesseract_dim: int,
row_rank: int,
col_rank: int,
dep_rank: int,
col_parallel_mode: ParallelMode,
skip_bias_add: bool,
data_parallel_rank: int,
pipeline_parallel_rank: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
) -> Tensor:
r"""Matrix add bias: :math:`C = A + b`.
Args:
@@ -618,9 +837,21 @@ def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, t
The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_
"""
return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank,
col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
return _Add_Bias_2p5D.apply(
input,
bias,
output_size_per_partition,
tesseract_dim,
row_rank,
col_rank,
dep_rank,
col_parallel_mode,
skip_bias_add,
data_parallel_rank,
pipeline_parallel_rank,
pipeline_parallel_size,
tensor_parallel_size,
)
class _Layernorm2p5D(torch.autograd.Function):
@@ -640,8 +871,9 @@ class _Layernorm2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
row_parallel_mode: ParallelMode) -> Tensor:
def forward(
ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode
) -> Tensor:
input = input - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
ctx.hidden_size = hidden_size
@@ -673,8 +905,9 @@ class _Layernorm2p5D(torch.autograd.Function):
return input_grad, None, None, None, None, None, None
def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
row_parallel_mode: ParallelMode) -> Tensor:
def layernorm_2p5d(
input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode
) -> Tensor:
r"""Layernorm.
Args:
@@ -692,7 +925,6 @@ def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
class _AllGatherTensor2p5D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
@@ -753,9 +985,9 @@ class SplitFirst(torch.autograd.Function):
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad_shape = (ctx.batch_size,) + output_grad.shape[1:]
grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
dist.all_gather(list(grad.chunk(ctx.tesseract_dim, dim=0)),
output_grad.contiguous(),
group=gpc.get_group(ctx.para_mode))
dist.all_gather(
list(grad.chunk(ctx.tesseract_dim, dim=0)), output_grad.contiguous(), group=gpc.get_group(ctx.para_mode)
)
return grad, None, None
@@ -775,15 +1007,16 @@ def split_batch_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
if world_size <= 1:
return input_
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
assert (
dim_size % world_size == 0
), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})."
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL), dim=dim)[
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
].contiguous()
class _ReduceTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@@ -808,7 +1041,6 @@ def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
class _ReduceScatterTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
@@ -834,14 +1066,14 @@ def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: Parallel
"""
dim_size = input_.size(dim)
world_size = gpc.get_world_size(parallel_mode)
assert dim_size % world_size == 0, \
f'The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size}).'
assert (
dim_size % world_size == 0
), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})."
return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)
class _RreduceByBatch2p5D(torch.autograd.Function):
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)

View File

@@ -7,19 +7,24 @@ def get_tesseract_dim_dep_from_env():
try:
tesseract_dim = env.tesseract_dim
tesseract_dep = env.tesseract_dep
assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
assert tesseract_dim > 0, "TESSERACT_DIM must be larger than zero"
assert tesseract_dep > 0, "TESSERACT_DEP must be larger than zero"
return tesseract_dim, tesseract_dep
except KeyError as e:
raise EnvironmentError('TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, '
'please make sure that you have used the correct process group initializer')
except KeyError:
raise EnvironmentError(
"TESSERACT_DIM or TESSERACT_DEP is not found in the current environment, "
"please make sure that you have used the correct process group initializer"
)
def assert_tesseract_initialization():
assert gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP) and \
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ), \
'Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ ' \
'must be initialized by the process group initializer'
assert (
gpc.is_initialized(ParallelMode.PARALLEL_2P5D_COL)
and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_ROW)
and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_DEP)
and gpc.is_initialized(ParallelMode.PARALLEL_2P5D_XZ)
), (
"Both PARALLEL_2P5D_COL, PARALLEL_2P5D_ROW, PARALLEL_2P5D_DEP and PARALLEL_2P5D_XZ "
"must be initialized by the process group initializer"
)

View File

@@ -56,14 +56,16 @@ class Linear2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
@@ -82,15 +84,16 @@ class Linear2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(out_features, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
self.weight = Parameter(
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs))
torch.empty(self.input_size_per_partition, self.hidden_size_per_partition, **factory_kwargs)
)
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)
# initialize parameters
with seed(ParallelMode.TENSOR):
@@ -110,8 +113,8 @@ class Linear2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -124,43 +127,33 @@ class Linear2p5D(ParallelLayer):
local_state[bias_key] = bias
# broadcast in dep groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0 and \
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
if (
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL) == 0
and gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0
):
broadcast_state_dict(local_state, ParallelMode.PARALLEL_2P5D_DEP)
# partition in column groups
if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW) == 0:
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in row groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0:
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -169,14 +162,8 @@ class Linear2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in column groups
@@ -184,14 +171,8 @@ class Linear2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -221,16 +202,38 @@ class Linear2p5D(ParallelLayer):
if self.bias is not None:
if self.skip_bias_add:
bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
bias = add_bias_2p5d(
None,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
return output, bias
else:
output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL,
False, self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_size, self.tensor_parallel_size)
output = add_bias_2p5d(
output,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
return output
else:
return output
@@ -266,10 +269,10 @@ class LayerNorm2p5D(ParallelLayer):
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
self.partitioned_partition = divide(normalized_shape, self.tesseract_dim) # *
# create parameters
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
self.weight = Parameter(torch.ones(self.partitioned_partition, **factory_kwargs))
if bias:
@@ -286,8 +289,8 @@ class LayerNorm2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -303,34 +306,22 @@ class LayerNorm2p5D(ParallelLayer):
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict({weight_key: self.weight})
if self.bias is not None:
local_state[bias_key] = self.bias
@@ -339,14 +330,8 @@ class LayerNorm2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
# gather in row groups
@@ -354,14 +339,8 @@ class LayerNorm2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -369,29 +348,51 @@ class LayerNorm2p5D(ParallelLayer):
def forward(self, x: Tensor) -> Tensor:
with torch.no_grad():
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
E_x = torch.sum(x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(E_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
E_x /= self.normalized_shape
# Var_x in the block below is the sum of input^2
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
Var_x = torch.sum(x * x, dim=-1, keepdim=True) # [b/q, s, 1]
torch.distributed.all_reduce(Var_x, group=gpc.get_group(ParallelMode.PARALLEL_2P5D_ROW))
Var_x /= self.normalized_shape
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
Var_x = Var_x - E_x * E_x # variance of x [b/q, s, 1]
# this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
scale = add_bias_2p5d(None, self.weight, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
scale = add_bias_2p5d(
None,
self.weight,
self.partitioned_partition,
self.tesseract_dim,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
if self.bias is not None:
bias = add_bias_2p5d(None, self.bias, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
bias = add_bias_2p5d(
None,
self.bias,
self.partitioned_partition,
self.tesseract_dim,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
True,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
output = torch.addcmul(bias, scale, output)
else:
output = torch.mul(scale, output)
@@ -420,16 +421,18 @@ class PatchEmbedding2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_(),
):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@@ -446,17 +449,22 @@ class PatchEmbedding2p5D(ParallelLayer):
with seed(ParallelMode.TENSOR):
self.weight = Parameter(
torch.empty((self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype))
torch.empty(
(self.embed_size_per_partition, in_chans, *self.patch_size),
device=get_current_device(),
dtype=dtype,
)
)
self.bias = Parameter(torch.empty(self.embed_size_per_partition, device=get_current_device(), dtype=dtype))
self.cls_token = Parameter(
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype))
torch.zeros((1, 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype)
)
self.pos_embed = Parameter(
torch.zeros((1, self.num_patches + 1, self.embed_size_per_partition),
device=get_current_device(),
dtype=dtype))
torch.zeros(
(1, self.num_patches + 1, self.embed_size_per_partition), device=get_current_device(), dtype=dtype
)
)
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
self._set_tensor_parallel_attribute()
@@ -477,10 +485,10 @@ class PatchEmbedding2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -504,67 +512,34 @@ class PatchEmbedding2p5D(ParallelLayer):
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
pos_embed_key = prefix + 'pos_embed'
local_state = OrderedDict({
weight_key: self.weight,
bias_key: self.bias,
cls_token_key: self.cls_token,
pos_embed_key: self.pos_embed
})
weight_key = prefix + "weight"
bias_key = prefix + "bias"
cls_token_key = prefix + "cls_token"
pos_embed_key = prefix + "pos_embed"
local_state = OrderedDict(
{weight_key: self.weight, bias_key: self.bias, cls_token_key: self.cls_token, pos_embed_key: self.pos_embed}
)
# gather in column groups
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
keep_vars=keep_vars,
)
# gather in row groups
@@ -572,18 +547,8 @@ class PatchEmbedding2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: 0,
bias_key: 0,
cls_token_key: -1,
pos_embed_key: -1
},
partition_states={
weight_key: True,
bias_key: True,
cls_token_key: True,
pos_embed_key: True
},
dims={weight_key: 0, bias_key: 0, cls_token_key: -1, pos_embed_key: -1},
partition_states={weight_key: True, bias_key: True, cls_token_key: True, pos_embed_key: True},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -593,15 +558,16 @@ class PatchEmbedding2p5D(ParallelLayer):
input_ = split_batch_2p5d(input_, 0)
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
@@ -643,14 +609,16 @@ class Embedding2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
assert_tesseract_initialization()
@@ -664,7 +632,8 @@ class Embedding2p5D(ParallelLayer):
self.embed_kwargs = kwargs
self.weight = Parameter(
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))
torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -685,7 +654,7 @@ class Embedding2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -711,7 +680,7 @@ class Embedding2p5D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in column groups
@@ -775,14 +744,16 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
@@ -799,9 +770,12 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
torch.empty(
(self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype,
)
)
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
@@ -817,14 +791,13 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None and \
self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
if self.padding_idx is not None and self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
weight = state_dict.pop(weight_key, None)
@@ -850,7 +823,7 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
weight_key = prefix + "weight"
local_state = OrderedDict({weight_key: self.weight})
# gather in column groups
@@ -880,11 +853,12 @@ class VocabParallelEmbedding2p5D(ParallelLayer):
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL)
return output
@@ -909,14 +883,16 @@ class Classifier2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
@@ -934,7 +910,8 @@ class Classifier2p5D(ParallelLayer):
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype))
torch.empty(self.num_classes, self.input_size_per_partition, device=get_current_device(), dtype=dtype)
)
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
@@ -964,8 +941,8 @@ class Classifier2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -983,34 +960,22 @@ class Classifier2p5D(ParallelLayer):
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
local_state = OrderedDict()
if self.has_weight:
local_state[weight_key] = self.weight
@@ -1021,14 +986,8 @@ class Classifier2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
# gather in row groups
@@ -1036,14 +995,8 @@ class Classifier2p5D(ParallelLayer):
local_state = gather_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: False
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: False},
keep_vars=keep_vars,
)
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@@ -1052,10 +1005,21 @@ class Classifier2p5D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes,)
return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return classifier_2p5d(
input_,
self.weight,
self.bias,
self.tesseract_dim,
out_shape,
self.row_rank,
self.col_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
@LAYERS.register_module
@@ -1077,14 +1041,16 @@ class VocabParallelClassifier2p5D(ParallelLayer):
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
def __init__(
self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
super().__init__()
self.in_features = in_features
@@ -1102,13 +1068,14 @@ class VocabParallelClassifier2p5D(ParallelLayer):
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
factory_kwargs = {"device": get_current_device(), "dtype": dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs))
torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs)
)
self.has_weight = True
# create bias, shape: [h/q]
if bias:
@@ -1137,8 +1104,8 @@ class VocabParallelClassifier2p5D(ParallelLayer):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
weight_key = prefix + "weight"
bias_key = prefix + "bias"
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
# weight
if self.has_weight:
@@ -1156,27 +1123,15 @@ class VocabParallelClassifier2p5D(ParallelLayer):
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_ROW,
dims={
weight_key: -1,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: -1, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
# partition in column groups
local_state = partition_tensor_parallel_state_dict(
local_state,
ParallelMode.PARALLEL_2P5D_COL,
dims={
weight_key: 0,
bias_key: 0
},
partition_states={
weight_key: True,
bias_key: True
},
dims={weight_key: 0, bias_key: 0},
partition_states={weight_key: True, bias_key: True},
)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
@@ -1203,8 +1158,19 @@ class VocabParallelClassifier2p5D(ParallelLayer):
)
if self.bias is not None:
output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
output = add_bias_2p5d(
output,
self.bias,
self.hidden_size_per_partition,
self.tesseract_dim,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_COL,
False,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
return output