fix layers/schedule for hybrid parallelization (#111) (#112)

ver217 committed 2022-01-04 20:52:31 +08:00 (committed by GitHub)
parent f03bcb359b
commit 7904baf6e1
6 changed files with 44 additions and 18 deletions

@@ -133,7 +133,7 @@ class GPTBlock(CheckpointModule):
                  dtype: dtype = None,
                  bias: bool = True,
                  checkpoint: bool = False):
-        super().__init__()
+        super().__init__(checkpoint=checkpoint)
         self.norm1 = col_nn.LayerNorm(normalized_shape=dim, eps=1e-6, dtype=dtype)
         self.attn = GPTSelfAttention(dim=dim,
                                      num_heads=num_heads,
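
Context for the one-line change above: GPTBlock subclasses CheckpointModule, but the old code called a bare super().__init__(), so the checkpoint argument accepted by GPTBlock never reached the base class and activation checkpointing silently stayed at the base-class default. The sketch below illustrates the failure mode and the fix; this CheckpointModule is a simplified stand-in (assumed to gate torch.utils.checkpoint on a constructor flag), not ColossalAI's actual implementation.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint


class CheckpointModule(nn.Module):
    # Simplified stand-in: subclasses implement _forward(), and forward()
    # wraps it with activation checkpointing when the flag is set.
    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self._use_checkpoint = checkpoint

    def _forward(self, *args):
        raise NotImplementedError

    def forward(self, *args):
        if self._use_checkpoint and self.training:
            # Recompute activations during backward instead of storing them.
            return torch_checkpoint(self._forward, *args, use_reentrant=False)
        return self._forward(*args)


class Block(CheckpointModule):
    def __init__(self, dim: int, checkpoint: bool = False):
        # The one-line fix from the diff: forward the flag. With the old
        # bare super().__init__(), the base class kept its default and a
        # caller's checkpoint=True was silently ignored.
        super().__init__(checkpoint=checkpoint)
        self.linear = nn.Linear(dim, dim)

    def _forward(self, x):
        return torch.relu(self.linear(x))


# With the fix, checkpoint=True actually enables recomputation:
blk = Block(dim=8, checkpoint=True).train()
out = blk(torch.randn(2, 8, requires_grad=True))
out.sum().backward()

Because the bug only changes memory behavior, not outputs, it is easy to miss in correctness tests, which is consistent with it surfacing during hybrid-parallel runs where activation memory matters.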